mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 15:26:40 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
206
client/ssh/server/command_execution.go
Normal file
206
client/ssh/server/command_execution.go
Normal file
@@ -0,0 +1,206 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handleCommand executes an SSH command with privilege validation
|
||||
func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, winCh <-chan ssh.Window) {
|
||||
hasPty := winCh != nil
|
||||
|
||||
commandType := "command"
|
||||
if hasPty {
|
||||
commandType = "Pty command"
|
||||
}
|
||||
|
||||
logger.Infof("executing %s: %s", commandType, safeLogCommand(session.Command()))
|
||||
|
||||
execCmd, cleanup, err := s.createCommand(privilegeResult, session, hasPty)
|
||||
if err != nil {
|
||||
logger.Errorf("%s creation failed: %v", commandType, err)
|
||||
|
||||
errorMsg := fmt.Sprintf("Cannot create %s - platform may not support user switching", commandType)
|
||||
if hasPty {
|
||||
errorMsg += " with Pty"
|
||||
}
|
||||
errorMsg += "\n"
|
||||
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if !hasPty {
|
||||
if s.executeCommand(logger, session, execCmd, cleanup) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
defer cleanup()
|
||||
|
||||
ptyReq, _, _ := session.Pty()
|
||||
if s.executeCommandWithPty(logger, session, execCmd, privilegeResult, ptyReq, winCh) {
|
||||
logger.Debugf("%s execution completed", commandType)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh.Session, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, nil, errors.New("no user in privilege result")
|
||||
}
|
||||
|
||||
// If PTY requested but su doesn't support --pty, skip su and use executor
|
||||
// This ensures PTY functionality is provided (executor runs within our allocated PTY)
|
||||
if hasPty && !s.suSupportsPty {
|
||||
log.Debugf("PTY requested but su doesn't support --pty, using executor for PTY functionality")
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// Try su first for system integration (PAM/audit) when privileged
|
||||
cmd, err := s.createSuCommand(session, localUser, hasPty)
|
||||
if err != nil || privilegeResult.UsedFallback {
|
||||
log.Debugf("su command failed, falling back to executor: %v", err)
|
||||
cmd, cleanup, err := s.createExecutorCommand(session, localUser, hasPty)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create command with privileges: %w", err)
|
||||
}
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
cmd.Env = s.prepareCommandEnv(localUser, session)
|
||||
return cmd, func() {}, nil
|
||||
}
|
||||
|
||||
// executeCommand executes the command and handles I/O and exit codes
|
||||
func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, cleanup func()) bool {
|
||||
defer cleanup()
|
||||
|
||||
s.setupProcessGroup(execCmd)
|
||||
|
||||
stdinPipe, err := execCmd.StdinPipe()
|
||||
if err != nil {
|
||||
logger.Errorf("create stdin pipe: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
execCmd.Stdout = session
|
||||
execCmd.Stderr = session.Stderr()
|
||||
|
||||
if execCmd.Dir != "" {
|
||||
if _, err := os.Stat(execCmd.Dir); err != nil {
|
||||
logger.Warnf("working directory does not exist: %s (%v)", execCmd.Dir, err)
|
||||
execCmd.Dir = "/"
|
||||
}
|
||||
}
|
||||
|
||||
if err := execCmd.Start(); err != nil {
|
||||
logger.Errorf("command start failed: %v", err)
|
||||
// no user message for exec failure, just exit
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
go s.handleCommandIO(logger, stdinPipe, session)
|
||||
return s.waitForCommandCleanup(logger, session, execCmd)
|
||||
}
|
||||
|
||||
// handleCommandIO manages stdin/stdout copying in a goroutine
|
||||
func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, session ssh.Session) {
|
||||
defer func() {
|
||||
if err := stdinPipe.Close(); err != nil {
|
||||
logger.Debugf("stdin pipe close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(stdinPipe, session); err != nil {
|
||||
logger.Debugf("stdin copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// waitForCommandCleanup waits for command completion with session disconnect handling
|
||||
func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session cancelled, terminating command")
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
logger.Tracef("command terminated after session cancellation: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
|
||||
case err := <-done:
|
||||
return s.handleCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
// handleCommandCompletion handles command completion
|
||||
func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
|
||||
if err != nil {
|
||||
logger.Debugf("command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return false
|
||||
}
|
||||
|
||||
s.handleSessionExit(session, nil, logger)
|
||||
return true
|
||||
}
|
||||
|
||||
// handleSessionExit handles command errors and sets appropriate exit codes
|
||||
func (s *Server) handleSessionExit(session ssh.Session, err error, logger *log.Entry) {
|
||||
if err == nil {
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
if err := session.Exit(exitError.ExitCode()); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
} else {
|
||||
logger.Debugf("non-exit error in command execution: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
52
client/ssh/server/command_execution_js.go
Normal file
52
client/ssh/server/command_execution_js.go
Normal file
@@ -0,0 +1,52 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var errNotSupported = errors.New("SSH server command execution not supported on WASM/JS platform")
|
||||
|
||||
// createSuCommand is not supported on JS/WASM
|
||||
func (s *Server) createSuCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, error) {
|
||||
return nil, errNotSupported
|
||||
}
|
||||
|
||||
// createExecutorCommand is not supported on JS/WASM
|
||||
func (s *Server) createExecutorCommand(_ ssh.Session, _ *user.User, _ bool) (*exec.Cmd, func(), error) {
|
||||
return nil, nil, errNotSupported
|
||||
}
|
||||
|
||||
// prepareCommandEnv is not supported on JS/WASM
|
||||
func (s *Server) prepareCommandEnv(_ *user.User, _ ssh.Session) []string {
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
}
|
||||
|
||||
// killProcessGroup is not supported on JS/WASM
|
||||
func (s *Server) killProcessGroup(*exec.Cmd) {
|
||||
}
|
||||
|
||||
// detectSuPtySupport always returns false on JS/WASM
|
||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty is not supported on JS/WASM
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
logger.Errorf("PTY command execution not supported on JS/WASM")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
329
client/ssh/server/command_execution_unix.go
Normal file
329
client/ssh/server/command_execution_unix.go
Normal file
@@ -0,0 +1,329 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/creack/pty"
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// ptyManager manages Pty file operations with thread safety
|
||||
type ptyManager struct {
|
||||
file *os.File
|
||||
mu sync.RWMutex
|
||||
closed bool
|
||||
closeErr error
|
||||
once sync.Once
|
||||
}
|
||||
|
||||
func newPtyManager(file *os.File) *ptyManager {
|
||||
return &ptyManager{file: file}
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Close() error {
|
||||
pm.once.Do(func() {
|
||||
pm.mu.Lock()
|
||||
pm.closed = true
|
||||
pm.closeErr = pm.file.Close()
|
||||
pm.mu.Unlock()
|
||||
})
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
return pm.closeErr
|
||||
}
|
||||
|
||||
func (pm *ptyManager) Setsize(ws *pty.Winsize) error {
|
||||
pm.mu.RLock()
|
||||
defer pm.mu.RUnlock()
|
||||
if pm.closed {
|
||||
return errors.New("pty is closed")
|
||||
}
|
||||
return pty.Setsize(pm.file, ws)
|
||||
}
|
||||
|
||||
func (pm *ptyManager) File() *os.File {
|
||||
return pm.file
|
||||
}
|
||||
|
||||
// detectSuPtySupport checks if su supports the --pty flag
|
||||
func (s *Server) detectSuPtySupport(ctx context.Context) bool {
|
||||
ctx, cancel := context.WithTimeout(ctx, 500*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "su", "--help")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
log.Debugf("su --help failed (may not support --help): %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
supported := strings.Contains(string(output), "--pty")
|
||||
log.Debugf("su --pty support detected: %v", supported)
|
||||
return supported
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching
|
||||
func (s *Server) createSuCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, error) {
|
||||
suPath, err := exec.LookPath("su")
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("su command not available: %w", err)
|
||||
}
|
||||
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
return nil, fmt.Errorf("no command specified for su execution")
|
||||
}
|
||||
|
||||
args := []string{"-l"}
|
||||
if hasPty && s.suSupportsPty {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
args = append(args, localUser.Username, "-c", command)
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), suPath, args...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
|
||||
return cmd, nil
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-l"}
|
||||
}
|
||||
return []string{shell, "-l", "-c", cmdString}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Unix
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
termType := ptyReq.Term
|
||||
if termType == "" {
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
execCmd.Env = append(execCmd.Env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
execCmd, err := s.createPtyCommand(privilegeResult, ptyReq, session)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty command creation failed: %v", err)
|
||||
errorMsg := "User switching failed - login command not available\r\n"
|
||||
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Infof("starting interactive shell: %s", execCmd.Path)
|
||||
return s.runPtyCommand(logger, session, execCmd, ptyReq, winCh)
|
||||
}
|
||||
|
||||
// runPtyCommand runs a command with PTY management (common code for interactive and command execution)
|
||||
func (s *Server) runPtyCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
ptmx, err := s.startPtyCommandWithSize(execCmd, ptyReq)
|
||||
if err != nil {
|
||||
logger.Errorf("Pty start failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
ptyMgr := newPtyManager(ptmx)
|
||||
defer func() {
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
go s.handlePtyWindowResize(logger, session, ptyMgr, winCh)
|
||||
s.handlePtyIO(logger, session, ptyMgr)
|
||||
s.waitForPtyCompletion(logger, session, execCmd, ptyMgr)
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) startPtyCommandWithSize(execCmd *exec.Cmd, ptyReq ssh.Pty) (*os.File, error) {
|
||||
winSize := &pty.Winsize{
|
||||
Cols: uint16(ptyReq.Window.Width),
|
||||
Rows: uint16(ptyReq.Window.Height),
|
||||
}
|
||||
if winSize.Cols == 0 {
|
||||
winSize.Cols = 80
|
||||
}
|
||||
if winSize.Rows == 0 {
|
||||
winSize.Rows = 24
|
||||
}
|
||||
|
||||
ptmx, err := pty.StartWithSize(execCmd, winSize)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("start Pty: %w", err)
|
||||
}
|
||||
|
||||
return ptmx, nil
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWindowResize(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager, winCh <-chan ssh.Window) {
|
||||
for {
|
||||
select {
|
||||
case <-session.Context().Done():
|
||||
return
|
||||
case win, ok := <-winCh:
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
if err := ptyMgr.Setsize(&pty.Winsize{Rows: uint16(win.Height), Cols: uint16(win.Width)}); err != nil {
|
||||
logger.Debugf("Pty resize to %dx%d: %v", win.Width, win.Height, err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyIO(logger *log.Entry, session ssh.Session, ptyMgr *ptyManager) {
|
||||
ptmx := ptyMgr.File()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(ptmx, session); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty input copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Debugf("session close error: %v", err)
|
||||
}
|
||||
}()
|
||||
if _, err := io.Copy(session, ptmx); err != nil {
|
||||
if !errors.Is(err, io.EOF) && !errors.Is(err, syscall.EIO) {
|
||||
logger.Warnf("Pty output copy error: %v", err)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
func (s *Server) waitForPtyCompletion(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager) {
|
||||
ctx := session.Context()
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- execCmd.Wait()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
s.handlePtySessionCancellation(logger, session, execCmd, ptyMgr, done)
|
||||
case err := <-done:
|
||||
s.handlePtyCommandCompletion(logger, session, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtySessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, ptyMgr *ptyManager, done <-chan error) {
|
||||
logger.Debugf("Pty session cancelled, terminating command")
|
||||
if err := ptyMgr.Close(); err != nil {
|
||||
logger.Debugf("Pty close during session cancellation: %v", err)
|
||||
}
|
||||
|
||||
s.killProcessGroup(execCmd)
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command terminated after session cancellation with error: %v", err)
|
||||
} else {
|
||||
logger.Debugf("Pty command terminated after session cancellation")
|
||||
}
|
||||
case <-time.After(5 * time.Second):
|
||||
logger.Warnf("Pty command did not terminate within 5 seconds after session cancellation")
|
||||
}
|
||||
|
||||
if err := session.Exit(130); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyCommandCompletion(logger *log.Entry, session ssh.Session, err error) {
|
||||
if err != nil {
|
||||
logger.Debugf("Pty command execution failed: %v", err)
|
||||
s.handleSessionExit(session, err, logger)
|
||||
return
|
||||
}
|
||||
|
||||
// Normal completion
|
||||
logger.Debugf("Pty command completed successfully")
|
||||
if err := session.Exit(0); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(cmd *exec.Cmd) {
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Setpgid: true,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
pgid := cmd.Process.Pid
|
||||
|
||||
if err := syscall.Kill(-pgid, syscall.SIGTERM); err != nil {
|
||||
logger.Debugf("kill process group SIGTERM: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
const gracePeriod = 500 * time.Millisecond
|
||||
const checkInterval = 50 * time.Millisecond
|
||||
|
||||
ticker := time.NewTicker(checkInterval)
|
||||
defer ticker.Stop()
|
||||
|
||||
timeout := time.After(gracePeriod)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-timeout:
|
||||
if err := syscall.Kill(-pgid, syscall.SIGKILL); err != nil {
|
||||
logger.Debugf("kill process group SIGKILL: %v", err)
|
||||
}
|
||||
return
|
||||
case <-ticker.C:
|
||||
if err := syscall.Kill(-pgid, 0); err != nil {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
430
client/ssh/server/command_execution_windows.go
Normal file
430
client/ssh/server/command_execution_windows.go
Normal file
@@ -0,0 +1,430 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
"golang.org/x/sys/windows/registry"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh/server/winpty"
|
||||
)
|
||||
|
||||
// getUserEnvironment retrieves the Windows environment for the target user.
|
||||
// Follows OpenSSH's resilient approach with graceful degradation on failures.
|
||||
func (s *Server) getUserEnvironment(username, domain string) ([]string, error) {
|
||||
userToken, err := s.getUserToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.getUserEnvironmentWithToken(userToken, username, domain)
|
||||
}
|
||||
|
||||
// getUserEnvironmentWithToken retrieves the Windows environment using an existing token.
|
||||
func (s *Server) getUserEnvironmentWithToken(userToken windows.Handle, username, domain string) ([]string, error) {
|
||||
userProfile, err := s.loadUserProfile(userToken, username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to load user profile for %s\\%s: %v", domain, username, err)
|
||||
userProfile = fmt.Sprintf("C:\\Users\\%s", username)
|
||||
}
|
||||
|
||||
envMap := make(map[string]string)
|
||||
|
||||
if err := s.loadSystemEnvironment(envMap); err != nil {
|
||||
log.Debugf("failed to load system environment from registry: %v", err)
|
||||
}
|
||||
|
||||
s.setUserEnvironmentVariables(envMap, userProfile, username, domain)
|
||||
|
||||
var env []string
|
||||
for key, value := range envMap {
|
||||
env = append(env, key+"="+value)
|
||||
}
|
||||
|
||||
return env, nil
|
||||
}
|
||||
|
||||
// getUserToken creates a user token for the specified user.
|
||||
func (s *Server) getUserToken(username, domain string) (windows.Handle, error) {
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
token, err := privilegeDropper.createToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("generate S4U user token: %w", err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// loadUserProfile loads the Windows user profile and returns the profile path.
|
||||
func (s *Server) loadUserProfile(userToken windows.Handle, username, domain string) (string, error) {
|
||||
usernamePtr, err := windows.UTF16PtrFromString(username)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert username to UTF-16: %w", err)
|
||||
}
|
||||
|
||||
var domainUTF16 *uint16
|
||||
if domain != "" && domain != "." {
|
||||
domainUTF16, err = windows.UTF16PtrFromString(domain)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert domain to UTF-16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
type profileInfo struct {
|
||||
dwSize uint32
|
||||
dwFlags uint32
|
||||
lpUserName *uint16
|
||||
lpProfilePath *uint16
|
||||
lpDefaultPath *uint16
|
||||
lpServerName *uint16
|
||||
lpPolicyPath *uint16
|
||||
hProfile windows.Handle
|
||||
}
|
||||
|
||||
const PI_NOUI = 0x00000001
|
||||
|
||||
profile := profileInfo{
|
||||
dwSize: uint32(unsafe.Sizeof(profileInfo{})),
|
||||
dwFlags: PI_NOUI,
|
||||
lpUserName: usernamePtr,
|
||||
lpServerName: domainUTF16,
|
||||
}
|
||||
|
||||
userenv := windows.NewLazySystemDLL("userenv.dll")
|
||||
loadUserProfileW := userenv.NewProc("LoadUserProfileW")
|
||||
|
||||
ret, _, err := loadUserProfileW.Call(
|
||||
uintptr(userToken),
|
||||
uintptr(unsafe.Pointer(&profile)),
|
||||
)
|
||||
|
||||
if ret == 0 {
|
||||
return "", fmt.Errorf("LoadUserProfileW: %w", err)
|
||||
}
|
||||
|
||||
if profile.lpProfilePath == nil {
|
||||
return "", fmt.Errorf("LoadUserProfileW returned null profile path")
|
||||
}
|
||||
|
||||
profilePath := windows.UTF16PtrToString(profile.lpProfilePath)
|
||||
return profilePath, nil
|
||||
}
|
||||
|
||||
// loadSystemEnvironment loads system-wide environment variables from registry.
|
||||
func (s *Server) loadSystemEnvironment(envMap map[string]string) error {
|
||||
key, err := registry.OpenKey(registry.LOCAL_MACHINE,
|
||||
`SYSTEM\CurrentControlSet\Control\Session Manager\Environment`,
|
||||
registry.QUERY_VALUE)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open system environment registry key: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := key.Close(); err != nil {
|
||||
log.Debugf("close registry key: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
return s.readRegistryEnvironment(key, envMap)
|
||||
}
|
||||
|
||||
// readRegistryEnvironment reads environment variables from a registry key.
|
||||
func (s *Server) readRegistryEnvironment(key registry.Key, envMap map[string]string) error {
|
||||
names, err := key.ReadValueNames(0)
|
||||
if err != nil {
|
||||
return fmt.Errorf("read registry value names: %w", err)
|
||||
}
|
||||
|
||||
for _, name := range names {
|
||||
value, valueType, err := key.GetStringValue(name)
|
||||
if err != nil {
|
||||
log.Debugf("failed to read registry value %s: %v", name, err)
|
||||
continue
|
||||
}
|
||||
|
||||
finalValue := s.expandRegistryValue(value, valueType, name)
|
||||
s.setEnvironmentVariable(envMap, name, finalValue)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// expandRegistryValue expands registry values if they contain environment variables.
|
||||
func (s *Server) expandRegistryValue(value string, valueType uint32, name string) string {
|
||||
if valueType != registry.EXPAND_SZ {
|
||||
return value
|
||||
}
|
||||
|
||||
sourcePtr := windows.StringToUTF16Ptr(value)
|
||||
expandedBuffer := make([]uint16, 1024)
|
||||
expandedLen, err := windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s: %v", name, err)
|
||||
return value
|
||||
}
|
||||
|
||||
// If buffer was too small, retry with larger buffer
|
||||
if expandedLen > uint32(len(expandedBuffer)) {
|
||||
expandedBuffer = make([]uint16, expandedLen)
|
||||
expandedLen, err = windows.ExpandEnvironmentStrings(sourcePtr, &expandedBuffer[0], uint32(len(expandedBuffer)))
|
||||
if err != nil {
|
||||
log.Debugf("failed to expand environment string for %s on retry: %v", name, err)
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
if expandedLen > 0 && expandedLen <= uint32(len(expandedBuffer)) {
|
||||
return windows.UTF16ToString(expandedBuffer[:expandedLen-1])
|
||||
}
|
||||
return value
|
||||
}
|
||||
|
||||
// setEnvironmentVariable sets an environment variable with special handling for PATH.
|
||||
func (s *Server) setEnvironmentVariable(envMap map[string]string, name, value string) {
|
||||
upperName := strings.ToUpper(name)
|
||||
|
||||
if upperName == "PATH" {
|
||||
if existing, exists := envMap["PATH"]; exists && existing != value {
|
||||
envMap["PATH"] = existing + ";" + value
|
||||
} else {
|
||||
envMap["PATH"] = value
|
||||
}
|
||||
} else {
|
||||
envMap[upperName] = value
|
||||
}
|
||||
}
|
||||
|
||||
// setUserEnvironmentVariables sets critical user-specific environment variables.
|
||||
func (s *Server) setUserEnvironmentVariables(envMap map[string]string, userProfile, username, domain string) {
|
||||
envMap["USERPROFILE"] = userProfile
|
||||
|
||||
if len(userProfile) >= 2 && userProfile[1] == ':' {
|
||||
envMap["HOMEDRIVE"] = userProfile[:2]
|
||||
envMap["HOMEPATH"] = userProfile[2:]
|
||||
}
|
||||
|
||||
envMap["APPDATA"] = filepath.Join(userProfile, "AppData", "Roaming")
|
||||
envMap["LOCALAPPDATA"] = filepath.Join(userProfile, "AppData", "Local")
|
||||
|
||||
tempDir := filepath.Join(userProfile, "AppData", "Local", "Temp")
|
||||
envMap["TEMP"] = tempDir
|
||||
envMap["TMP"] = tempDir
|
||||
|
||||
envMap["USERNAME"] = username
|
||||
if domain != "" && domain != "." {
|
||||
envMap["USERDOMAIN"] = domain
|
||||
envMap["USERDNSDOMAIN"] = domain
|
||||
}
|
||||
|
||||
systemVars := []string{
|
||||
"PROCESSOR_ARCHITECTURE", "PROCESSOR_IDENTIFIER", "PROCESSOR_LEVEL", "PROCESSOR_REVISION",
|
||||
"SYSTEMDRIVE", "SYSTEMROOT", "WINDIR", "COMPUTERNAME", "OS", "PATHEXT",
|
||||
"PROGRAMFILES", "PROGRAMDATA", "ALLUSERSPROFILE", "COMSPEC",
|
||||
}
|
||||
|
||||
for _, sysVar := range systemVars {
|
||||
if sysValue := os.Getenv(sysVar); sysValue != "" {
|
||||
envMap[sysVar] = sysValue
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// prepareCommandEnv prepares environment variables for command execution on Windows
|
||||
func (s *Server) prepareCommandEnv(localUser *user.User, session ssh.Session) []string {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
userEnv, err := s.getUserEnvironment(username, domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using fallback: %v", domain, username, err)
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
env := userEnv
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
if privilegeResult.User == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
cmd := session.Command()
|
||||
shell := getUserShell(privilegeResult.User.Uid)
|
||||
|
||||
if len(cmd) == 0 {
|
||||
logger.Infof("starting interactive shell: %s", shell)
|
||||
} else {
|
||||
logger.Infof("executing command: %s", safeLogCommand(cmd))
|
||||
}
|
||||
|
||||
s.handlePtyWithUserSwitching(logger, session, privilegeResult, ptyReq, winCh, cmd)
|
||||
return true
|
||||
}
|
||||
|
||||
// getShellCommandArgs returns the shell command and arguments for executing a command string
|
||||
func (s *Server) getShellCommandArgs(shell, cmdString string) []string {
|
||||
if cmdString == "" {
|
||||
return []string{shell, "-NoLogo"}
|
||||
}
|
||||
return []string{shell, "-Command", cmdString}
|
||||
}
|
||||
|
||||
func (s *Server) handlePtyWithUserSwitching(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, _ <-chan ssh.Window, _ []string) {
|
||||
logger.Info("starting interactive shell")
|
||||
s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, session.RawCommand())
|
||||
}
|
||||
|
||||
type PtyExecutionRequest struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
Username string
|
||||
Domain string
|
||||
}
|
||||
|
||||
func executePtyCommandWithUserToken(ctx context.Context, session ssh.Session, req PtyExecutionRequest) error {
|
||||
log.Tracef("executing Windows ConPty command with user switching: shell=%s, command=%s, user=%s\\%s, size=%dx%d",
|
||||
req.Shell, req.Command, req.Domain, req.Username, req.Width, req.Height)
|
||||
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
userToken, err := privilegeDropper.createToken(req.Username, req.Domain)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create user token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(userToken); err != nil {
|
||||
log.Debugf("close user token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
server := &Server{}
|
||||
userEnv, err := server.getUserEnvironmentWithToken(userToken, req.Username, req.Domain)
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user environment for %s\\%s, using system environment: %v", req.Domain, req.Username, err)
|
||||
userEnv = os.Environ()
|
||||
}
|
||||
|
||||
workingDir := getUserHomeFromEnv(userEnv)
|
||||
if workingDir == "" {
|
||||
workingDir = fmt.Sprintf(`C:\Users\%s`, req.Username)
|
||||
}
|
||||
|
||||
ptyConfig := winpty.PtyConfig{
|
||||
Shell: req.Shell,
|
||||
Command: req.Command,
|
||||
Width: req.Width,
|
||||
Height: req.Height,
|
||||
WorkingDir: workingDir,
|
||||
}
|
||||
|
||||
userConfig := winpty.UserConfig{
|
||||
Token: userToken,
|
||||
Environment: userEnv,
|
||||
}
|
||||
|
||||
log.Debugf("executePtyCommandWithUserToken: calling winpty execution with working dir: %s", workingDir)
|
||||
return winpty.ExecutePtyWithUserToken(ctx, session, ptyConfig, userConfig)
|
||||
}
|
||||
|
||||
func getUserHomeFromEnv(env []string) string {
|
||||
for _, envVar := range env {
|
||||
if len(envVar) > 12 && envVar[:12] == "USERPROFILE=" {
|
||||
return envVar[12:]
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
func (s *Server) setupProcessGroup(_ *exec.Cmd) {
|
||||
// Windows doesn't support process groups in the same way as Unix
|
||||
// Process creation groups are handled differently
|
||||
}
|
||||
|
||||
func (s *Server) killProcessGroup(cmd *exec.Cmd) {
|
||||
if cmd.Process == nil {
|
||||
return
|
||||
}
|
||||
|
||||
logger := log.WithField("pid", cmd.Process.Pid)
|
||||
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
logger.Debugf("kill process failed: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// detectSuPtySupport always returns false on Windows as su is not available
|
||||
func (s *Server) detectSuPtySupport(context.Context) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// executeCommandWithPty executes a command with PTY allocation on Windows using ConPty
|
||||
func (s *Server) executeCommandWithPty(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, winCh <-chan ssh.Window) bool {
|
||||
command := session.RawCommand()
|
||||
if command == "" {
|
||||
logger.Error("no command specified for PTY execution")
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
return s.executeConPtyCommand(logger, session, privilegeResult, ptyReq, command)
|
||||
}
|
||||
|
||||
// executeConPtyCommand executes a command using ConPty (common for interactive and command execution)
|
||||
func (s *Server) executeConPtyCommand(logger *log.Entry, session ssh.Session, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, command string) bool {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
logger.Errorf("no user in privilege result")
|
||||
return false
|
||||
}
|
||||
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
req := PtyExecutionRequest{
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Width: ptyReq.Window.Width,
|
||||
Height: ptyReq.Window.Height,
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
}
|
||||
|
||||
if err := executePtyCommandWithUserToken(session.Context(), session, req); err != nil {
|
||||
logger.Errorf("ConPty execution failed: %v", err)
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
logger.Debug("ConPty execution completed")
|
||||
return true
|
||||
}
|
||||
722
client/ssh/server/compatibility_test.go
Normal file
722
client/ssh/server/compatibility_test.go
Normal file
@@ -0,0 +1,722 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/ed25519"
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
// TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client
|
||||
func TestSSHServerCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH compatibility tests in short mode")
|
||||
}
|
||||
|
||||
// Check if ssh binary is available
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Set up SSH server - use our existing key generation for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate OpenSSH-compatible keys for client
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Create temporary key files for SSH client
|
||||
clientKeyFile, cleanupKey := createTempKeyFileFromBytes(t, clientPrivKeyOpenSSH)
|
||||
defer cleanupKey()
|
||||
|
||||
// Extract host and port from server address
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get appropriate user for SSH connection (handle system accounts)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("basic command execution", func(t *testing.T) {
|
||||
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
||||
})
|
||||
|
||||
t.Run("interactive command", func(t *testing.T) {
|
||||
testSSHInteractiveCommand(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
|
||||
t.Run("port forwarding", func(t *testing.T) {
|
||||
testSSHPortForwarding(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
}
|
||||
|
||||
// testSSHCommandExecutionWithUser tests basic command execution with system SSH client using specified user.
|
||||
func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username string) {
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "hello_world")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("SSH command failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully")
|
||||
}
|
||||
|
||||
// testSSHInteractiveCommand tests interactive shell session.
|
||||
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
cmd := exec.CommandContext(ctx, "ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host))
|
||||
|
||||
stdin, err := cmd.StdinPipe()
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create stdin pipe: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
t.Skipf("Cannot create stdout pipe: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
t.Logf("Cannot start SSH session: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
go func() {
|
||||
defer func() {
|
||||
if err := stdin.Close(); err != nil {
|
||||
t.Logf("stdin close error: %v", err)
|
||||
}
|
||||
}()
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := stdin.Write([]byte("echo interactive_test\n")); err != nil {
|
||||
t.Logf("stdin write error: %v", err)
|
||||
}
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
if _, err := stdin.Write([]byte("exit\n")); err != nil {
|
||||
t.Logf("stdin write error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
output, err := io.ReadAll(stdout)
|
||||
if err != nil {
|
||||
t.Logf("Cannot read SSH output: %v", err)
|
||||
}
|
||||
|
||||
err = cmd.Wait()
|
||||
if err != nil {
|
||||
t.Logf("SSH interactive session error: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "interactive_test", "Interactive SSH session should work")
|
||||
}
|
||||
|
||||
// testSSHPortForwarding tests port forwarding compatibility.
|
||||
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
defer testServer.Close()
|
||||
|
||||
testServerAddr := testServer.Addr().String()
|
||||
expectedResponse := "HTTP/1.1 200 OK\r\nContent-Length: 21\r\n\r\nCompatibility Test OK"
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := testServer.Accept()
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
go func(c net.Conn) {
|
||||
defer func() {
|
||||
if err := c.Close(); err != nil {
|
||||
t.Logf("test server connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
buf := make([]byte, 1024)
|
||||
if _, err := c.Read(buf); err != nil {
|
||||
t.Logf("Test server read error: %v", err)
|
||||
}
|
||||
if _, err := c.Write([]byte(expectedResponse)); err != nil {
|
||||
t.Logf("Test server write error: %v", err)
|
||||
}
|
||||
}(conn)
|
||||
}
|
||||
}()
|
||||
|
||||
localListener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
localAddr := localListener.Addr().String()
|
||||
localListener.Close()
|
||||
|
||||
_, localPort, err := net.SplitHostPort(localAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 15*time.Second)
|
||||
defer cancel()
|
||||
|
||||
forwardSpec := fmt.Sprintf("%s:%s", localPort, testServerAddr)
|
||||
cmd := exec.CommandContext(ctx, "ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-L", forwardSpec,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-N",
|
||||
fmt.Sprintf("%s@%s", username, host))
|
||||
|
||||
err = cmd.Start()
|
||||
if err != nil {
|
||||
t.Logf("Cannot start SSH port forwarding: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if cmd.Process != nil {
|
||||
if err := cmd.Process.Kill(); err != nil {
|
||||
t.Logf("process kill error: %v", err)
|
||||
}
|
||||
}
|
||||
if err := cmd.Wait(); err != nil {
|
||||
t.Logf("process wait after kill: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
time.Sleep(500 * time.Millisecond)
|
||||
|
||||
conn, err := net.DialTimeout("tcp", localAddr, 3*time.Second)
|
||||
if err != nil {
|
||||
t.Logf("Cannot connect to forwarded port: %v", err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("forwarded connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
request := "GET / HTTP/1.1\r\nHost: localhost\r\nConnection: close\r\n\r\n"
|
||||
_, err = conn.Write([]byte(request))
|
||||
require.NoError(t, err)
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
|
||||
log.Debugf("failed to set read deadline: %v", err)
|
||||
}
|
||||
response := make([]byte, len(expectedResponse))
|
||||
n, err := io.ReadFull(conn, response)
|
||||
if err != nil {
|
||||
t.Logf("Cannot read forwarded response: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
assert.Equal(t, len(expectedResponse), n, "Should read expected number of bytes")
|
||||
assert.Equal(t, expectedResponse, string(response), "Should get correct HTTP response through SSH port forwarding")
|
||||
}
|
||||
|
||||
// isSSHClientAvailable checks if the ssh binary is available
|
||||
func isSSHClientAvailable() bool {
|
||||
_, err := exec.LookPath("ssh")
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// generateOpenSSHKey generates an ED25519 key in OpenSSH format that the system SSH client can use.
|
||||
func generateOpenSSHKey(t *testing.T) ([]byte, []byte, error) {
|
||||
// Check if ssh-keygen is available
|
||||
if _, err := exec.LookPath("ssh-keygen"); err != nil {
|
||||
// Fall back to our existing key generation and try to convert
|
||||
return generateOpenSSHKeyFallback()
|
||||
}
|
||||
|
||||
// Create temporary file for ssh-keygen
|
||||
tempFile, err := os.CreateTemp("", "ssh_keygen_*")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create temp file: %w", err)
|
||||
}
|
||||
keyPath := tempFile.Name()
|
||||
tempFile.Close()
|
||||
|
||||
// Remove the temp file so ssh-keygen can create it
|
||||
if err := os.Remove(keyPath); err != nil {
|
||||
t.Logf("failed to remove key file: %v", err)
|
||||
}
|
||||
|
||||
// Clean up temp files
|
||||
defer func() {
|
||||
if err := os.Remove(keyPath); err != nil {
|
||||
t.Logf("failed to cleanup key file: %v", err)
|
||||
}
|
||||
if err := os.Remove(keyPath + ".pub"); err != nil {
|
||||
t.Logf("failed to cleanup public key file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Generate key using ssh-keygen
|
||||
cmd := exec.Command("ssh-keygen", "-t", "ed25519", "-f", keyPath, "-N", "", "-q")
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("ssh-keygen failed: %w, output: %s", err, string(output))
|
||||
}
|
||||
|
||||
// Read private key
|
||||
privKeyBytes, err := os.ReadFile(keyPath)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read private key: %w", err)
|
||||
}
|
||||
|
||||
// Read public key
|
||||
pubKeyBytes, err := os.ReadFile(keyPath + ".pub")
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("read public key: %w", err)
|
||||
}
|
||||
|
||||
return privKeyBytes, pubKeyBytes, nil
|
||||
}
|
||||
|
||||
// generateOpenSSHKeyFallback falls back to generating keys using our existing method
|
||||
func generateOpenSSHKeyFallback() ([]byte, []byte, error) {
|
||||
// Generate shared.ED25519 key pair using our existing method
|
||||
_, privKey, err := ed25519.GenerateKey(rand.Reader)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate key: %w", err)
|
||||
}
|
||||
|
||||
// Convert to SSH format
|
||||
sshPrivKey, err := ssh.NewSignerFromKey(privKey)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("create signer: %w", err)
|
||||
}
|
||||
|
||||
// For the fallback, just use our PKCS#8 format and hope it works
|
||||
// This won't be in OpenSSH format but might still work with some SSH clients
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("generate fallback key: %w", err)
|
||||
}
|
||||
|
||||
// Get public key in SSH format
|
||||
sshPubKey := ssh.MarshalAuthorizedKey(sshPrivKey.PublicKey())
|
||||
|
||||
return hostKey, sshPubKey, nil
|
||||
}
|
||||
|
||||
// createTempKeyFileFromBytes creates a temporary SSH private key file from raw bytes
|
||||
func createTempKeyFileFromBytes(t *testing.T, keyBytes []byte) (string, func()) {
|
||||
t.Helper()
|
||||
|
||||
tempFile, err := os.CreateTemp("", "ssh_test_key_*")
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = tempFile.Write(keyBytes)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = tempFile.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Set proper permissions for SSH key (readable by owner only)
|
||||
err = os.Chmod(tempFile.Name(), 0600)
|
||||
require.NoError(t, err)
|
||||
|
||||
cleanup := func() {
|
||||
_ = os.Remove(tempFile.Name())
|
||||
}
|
||||
|
||||
return tempFile.Name(), cleanup
|
||||
}
|
||||
|
||||
// createTempKeyFile creates a temporary SSH private key file (for backward compatibility)
|
||||
func createTempKeyFile(t *testing.T, privateKey []byte) (string, func()) {
|
||||
return createTempKeyFileFromBytes(t, privateKey)
|
||||
}
|
||||
|
||||
// TestSSHServerFeatureCompatibility tests specific SSH features for compatibility
|
||||
func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Test various SSH features
|
||||
testCases := []struct {
|
||||
name string
|
||||
testFunc func(t *testing.T, host, port, keyFile string)
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "command_with_flags",
|
||||
testFunc: testCommandWithFlags,
|
||||
description: "Commands with flags should work like standard SSH",
|
||||
},
|
||||
{
|
||||
name: "environment_variables",
|
||||
testFunc: testEnvironmentVariables,
|
||||
description: "Environment variables should be available",
|
||||
},
|
||||
{
|
||||
name: "exit_codes",
|
||||
testFunc: testExitCodes,
|
||||
description: "Exit codes should be properly handled",
|
||||
},
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
tc.testFunc(t, host, portStr, clientKeyFile)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// testCommandWithFlags tests that commands with flags work properly
|
||||
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test ls with flags
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"ls", "-la", "/tmp")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Command with flags failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
// Should not be empty and should not contain error messages
|
||||
assert.NotEmpty(t, string(output), "ls -la should produce output")
|
||||
assert.NotContains(t, strings.ToLower(string(output)), "command not found", "Command should be executed")
|
||||
}
|
||||
|
||||
// testEnvironmentVariables tests that environment is properly set up
|
||||
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "$HOME")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Environment test failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
// HOME environment variable should be available
|
||||
homeOutput := strings.TrimSpace(string(output))
|
||||
assert.NotEmpty(t, homeOutput, "HOME environment variable should be set")
|
||||
assert.NotEqual(t, "$HOME", homeOutput, "Environment variable should be expanded")
|
||||
}
|
||||
|
||||
// testExitCodes tests that exit codes are properly handled
|
||||
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test successful command (exit code 0)
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"true") // always succeeds
|
||||
|
||||
err := cmd.Run()
|
||||
assert.NoError(t, err, "Command with exit code 0 should succeed")
|
||||
|
||||
// Test failing command (exit code 1)
|
||||
cmd = exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
"-p", port,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"false") // always fails
|
||||
|
||||
err = cmd.Run()
|
||||
assert.Error(t, err, "Command with exit code 1 should fail")
|
||||
|
||||
// Check if it's the right kind of error
|
||||
if exitError, ok := err.(*exec.ExitError); ok {
|
||||
assert.Equal(t, 1, exitError.ExitCode(), "Exit code should be preserved")
|
||||
}
|
||||
}
|
||||
|
||||
// TestSSHServerSecurityFeatures tests security-related SSH features
|
||||
func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping SSH security tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server with specific security settings
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("key_authentication", func(t *testing.T) {
|
||||
// Test that key authentication works
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "auth_success")
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Key authentication failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
assert.Contains(t, string(output), "auth_success", "Key authentication should work")
|
||||
})
|
||||
|
||||
t.Run("any_key_accepted_in_no_auth_mode", func(t *testing.T) {
|
||||
// Create a different key that shouldn't be accepted
|
||||
wrongKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongKeyFile, cleanupWrongKey := createTempKeyFile(t, wrongKey)
|
||||
defer cleanupWrongKey()
|
||||
|
||||
// Test that wrong key is rejected
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", wrongKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
"-o", "PasswordAuthentication=no",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
"echo", "should_not_work")
|
||||
|
||||
err = cmd.Run()
|
||||
assert.NoError(t, err, "Any key should work in no-auth mode")
|
||||
})
|
||||
}
|
||||
|
||||
// TestCrossPlatformCompatibility tests cross-platform behavior
|
||||
func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping cross-platform compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if !isSSHClientAvailable() {
|
||||
t.Skip("SSH client not available on this system")
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientKeyFile, cleanupKey := createTempKeyFile(t, clientPrivKey)
|
||||
defer cleanupKey()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test platform-specific commands
|
||||
var testCommand string
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
testCommand = "echo %OS%"
|
||||
default:
|
||||
testCommand = "uname"
|
||||
}
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", clientKeyFile,
|
||||
"-p", portStr,
|
||||
"-o", "StrictHostKeyChecking=no",
|
||||
"-o", "UserKnownHostsFile=/dev/null",
|
||||
"-o", "ConnectTimeout=5",
|
||||
fmt.Sprintf("%s@%s", username, host),
|
||||
testCommand)
|
||||
|
||||
output, err := cmd.CombinedOutput()
|
||||
if err != nil {
|
||||
t.Logf("Platform-specific command failed: %v", err)
|
||||
t.Logf("Output: %s", string(output))
|
||||
return
|
||||
}
|
||||
|
||||
outputStr := strings.TrimSpace(string(output))
|
||||
t.Logf("Platform command output: %s", outputStr)
|
||||
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
||||
}
|
||||
253
client/ssh/server/executor_unix.go
Normal file
253
client/ssh/server/executor_unix.go
Normal file
@@ -0,0 +1,253 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"runtime"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// Exit codes for executor process communication
|
||||
const (
|
||||
ExitCodeSuccess = 0
|
||||
ExitCodePrivilegeDropFail = 10
|
||||
ExitCodeShellExecFail = 11
|
||||
ExitCodeValidationFail = 12
|
||||
)
|
||||
|
||||
// ExecutorConfig holds configuration for the executor process
|
||||
type ExecutorConfig struct {
|
||||
UID uint32
|
||||
GID uint32
|
||||
Groups []uint32
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
PTY bool
|
||||
}
|
||||
|
||||
// PrivilegeDropper handles secure privilege dropping in child processes
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
// NewPrivilegeDropper creates a new privilege dropper
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
// CreateExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping
|
||||
func (pd *PrivilegeDropper) CreateExecutorCommand(ctx context.Context, config ExecutorConfig) (*exec.Cmd, error) {
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get netbird executable path: %w", err)
|
||||
}
|
||||
|
||||
if err := pd.validatePrivileges(config.UID, config.GID); err != nil {
|
||||
return nil, fmt.Errorf("invalid privileges: %w", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "exec",
|
||||
"--uid", fmt.Sprintf("%d", config.UID),
|
||||
"--gid", fmt.Sprintf("%d", config.GID),
|
||||
"--working-dir", config.WorkingDir,
|
||||
"--shell", config.Shell,
|
||||
}
|
||||
|
||||
for _, group := range config.Groups {
|
||||
args = append(args, "--groups", fmt.Sprintf("%d", group))
|
||||
}
|
||||
|
||||
if config.PTY {
|
||||
args = append(args, "--pty")
|
||||
}
|
||||
|
||||
if config.Command != "" {
|
||||
args = append(args, "--cmd", config.Command)
|
||||
}
|
||||
|
||||
// Log executor args safely - show all args except hide the command value
|
||||
safeArgs := make([]string, len(args))
|
||||
copy(safeArgs, args)
|
||||
for i := 0; i < len(safeArgs)-1; i++ {
|
||||
if safeArgs[i] == "--cmd" {
|
||||
cmdParts := strings.Fields(safeArgs[i+1])
|
||||
safeArgs[i+1] = safeLogCommand(cmdParts)
|
||||
break
|
||||
}
|
||||
}
|
||||
log.Tracef("creating executor command: %s %v", netbirdPath, safeArgs)
|
||||
return exec.CommandContext(ctx, netbirdPath, args...), nil
|
||||
}
|
||||
|
||||
// DropPrivileges performs privilege dropping with thread locking for security
|
||||
func (pd *PrivilegeDropper) DropPrivileges(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
|
||||
if err := pd.validatePrivileges(targetUID, targetGID); err != nil {
|
||||
return fmt.Errorf("invalid privileges: %w", err)
|
||||
}
|
||||
|
||||
runtime.LockOSThread()
|
||||
defer runtime.UnlockOSThread()
|
||||
|
||||
originalUID := os.Geteuid()
|
||||
originalGID := os.Getegid()
|
||||
|
||||
if originalUID != int(targetUID) || originalGID != int(targetGID) {
|
||||
if err := pd.setGroupsAndIDs(targetUID, targetGID, supplementaryGroups); err != nil {
|
||||
return fmt.Errorf("set groups and IDs: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := pd.validatePrivilegeDropSuccess(targetUID, targetGID, originalUID, originalGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Tracef("successfully dropped privileges to UID=%d, GID=%d", targetUID, targetGID)
|
||||
return nil
|
||||
}
|
||||
|
||||
// setGroupsAndIDs sets the supplementary groups, GID, and UID
|
||||
func (pd *PrivilegeDropper) setGroupsAndIDs(targetUID, targetGID uint32, supplementaryGroups []uint32) error {
|
||||
groups := make([]int, len(supplementaryGroups))
|
||||
for i, g := range supplementaryGroups {
|
||||
groups[i] = int(g)
|
||||
}
|
||||
|
||||
if runtime.GOOS == "darwin" || runtime.GOOS == "freebsd" {
|
||||
if len(groups) == 0 || groups[0] != int(targetGID) {
|
||||
groups = append([]int{int(targetGID)}, groups...)
|
||||
}
|
||||
}
|
||||
|
||||
if err := syscall.Setgroups(groups); err != nil {
|
||||
return fmt.Errorf("setgroups to %v: %w", groups, err)
|
||||
}
|
||||
|
||||
if err := syscall.Setgid(int(targetGID)); err != nil {
|
||||
return fmt.Errorf("setgid to %d: %w", targetGID, err)
|
||||
}
|
||||
|
||||
if err := syscall.Setuid(int(targetUID)); err != nil {
|
||||
return fmt.Errorf("setuid to %d: %w", targetUID, err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePrivilegeDropSuccess validates that privilege dropping was successful
|
||||
func (pd *PrivilegeDropper) validatePrivilegeDropSuccess(targetUID, targetGID uint32, originalUID, originalGID int) error {
|
||||
if err := pd.validatePrivilegeDropReversibility(targetUID, targetGID, originalUID, originalGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := pd.validateCurrentPrivileges(targetUID, targetGID); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// validatePrivilegeDropReversibility ensures privileges cannot be restored
|
||||
func (pd *PrivilegeDropper) validatePrivilegeDropReversibility(targetUID, targetGID uint32, originalUID, originalGID int) error {
|
||||
if originalGID != int(targetGID) {
|
||||
if err := syscall.Setegid(originalGID); err == nil {
|
||||
return fmt.Errorf("privilege drop validation failed: able to restore original GID %d", originalGID)
|
||||
}
|
||||
}
|
||||
if originalUID != int(targetUID) {
|
||||
if err := syscall.Seteuid(originalUID); err == nil {
|
||||
return fmt.Errorf("privilege drop validation failed: able to restore original UID %d", originalUID)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateCurrentPrivileges validates the current UID and GID match the target
|
||||
func (pd *PrivilegeDropper) validateCurrentPrivileges(targetUID, targetGID uint32) error {
|
||||
currentUID := os.Geteuid()
|
||||
if currentUID != int(targetUID) {
|
||||
return fmt.Errorf("privilege drop validation failed: current UID %d, expected %d", currentUID, targetUID)
|
||||
}
|
||||
|
||||
currentGID := os.Getegid()
|
||||
if currentGID != int(targetGID) {
|
||||
return fmt.Errorf("privilege drop validation failed: current GID %d, expected %d", currentGID, targetGID)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// ExecuteWithPrivilegeDrop executes a command with privilege dropping, using exit codes to signal specific failures
|
||||
func (pd *PrivilegeDropper) ExecuteWithPrivilegeDrop(ctx context.Context, config ExecutorConfig) {
|
||||
log.Tracef("dropping privileges to UID=%d, GID=%d, groups=%v", config.UID, config.GID, config.Groups)
|
||||
|
||||
// TODO: Implement Pty support for executor path
|
||||
if config.PTY {
|
||||
config.PTY = false
|
||||
}
|
||||
|
||||
if err := pd.DropPrivileges(config.UID, config.GID, config.Groups); err != nil {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "privilege drop failed: %v\n", err)
|
||||
os.Exit(ExitCodePrivilegeDropFail)
|
||||
}
|
||||
|
||||
if config.WorkingDir != "" {
|
||||
if err := os.Chdir(config.WorkingDir); err != nil {
|
||||
log.Debugf("failed to change to working directory %s, continuing with current directory: %v", config.WorkingDir, err)
|
||||
}
|
||||
}
|
||||
|
||||
var execCmd *exec.Cmd
|
||||
if config.Command == "" {
|
||||
os.Exit(ExitCodeSuccess)
|
||||
}
|
||||
|
||||
execCmd = exec.CommandContext(ctx, config.Shell, "-c", config.Command)
|
||||
execCmd.Stdin = os.Stdin
|
||||
execCmd.Stdout = os.Stdout
|
||||
execCmd.Stderr = os.Stderr
|
||||
|
||||
cmdParts := strings.Fields(config.Command)
|
||||
safeCmd := safeLogCommand(cmdParts)
|
||||
log.Tracef("executing %s -c %s", execCmd.Path, safeCmd)
|
||||
if err := execCmd.Run(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
// Normal command exit with non-zero code - not an SSH execution error
|
||||
log.Tracef("command exited with code %d", exitError.ExitCode())
|
||||
os.Exit(exitError.ExitCode())
|
||||
}
|
||||
|
||||
// Actual execution failure (command not found, permission denied, etc.)
|
||||
log.Debugf("command execution failed: %v", err)
|
||||
os.Exit(ExitCodeShellExecFail)
|
||||
}
|
||||
|
||||
os.Exit(ExitCodeSuccess)
|
||||
}
|
||||
|
||||
// validatePrivileges validates that privilege dropping to the target UID/GID is allowed
|
||||
func (pd *PrivilegeDropper) validatePrivileges(uid, gid uint32) error {
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
// Allow same-user operations (no privilege dropping needed)
|
||||
if uid == currentUID && gid == currentGID {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Only root can drop privileges to other users
|
||||
if currentUID != 0 {
|
||||
return fmt.Errorf("cannot drop privileges from non-root user (UID %d) to UID %d", currentUID, uid)
|
||||
}
|
||||
|
||||
// Root can drop to any user (including root itself)
|
||||
return nil
|
||||
}
|
||||
262
client/ssh/server/executor_unix_test.go
Normal file
262
client/ssh/server/executor_unix_test.go
Normal file
@@ -0,0 +1,262 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
uid uint32
|
||||
gid uint32
|
||||
wantErr bool
|
||||
}{
|
||||
{
|
||||
name: "same user - no privilege drop needed",
|
||||
uid: currentUID,
|
||||
gid: currentGID,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "non-root to different user should fail",
|
||||
uid: currentUID + 1, // Use a different UID to ensure it's actually different
|
||||
gid: currentGID + 1, // Use a different GID to ensure it's actually different
|
||||
wantErr: currentUID != 0, // Only fail if current user is not root
|
||||
},
|
||||
{
|
||||
name: "root can drop to any user",
|
||||
uid: 1000,
|
||||
gid: 1000,
|
||||
wantErr: false,
|
||||
},
|
||||
{
|
||||
name: "root can stay as root",
|
||||
uid: 0,
|
||||
gid: 0,
|
||||
wantErr: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Skip non-root tests when running as root, and root tests when not root
|
||||
if tt.name == "non-root to different user should fail" && currentUID == 0 {
|
||||
t.Skip("Skipping non-root test when running as root")
|
||||
}
|
||||
if (tt.name == "root can drop to any user" || tt.name == "root can stay as root") && currentUID != 0 {
|
||||
t.Skip("Skipping root test when not running as root")
|
||||
}
|
||||
|
||||
err := pd.validatePrivileges(tt.uid, tt.gid)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000, 1001},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "ls -la",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify the command is calling netbird ssh exec
|
||||
assert.Contains(t, cmd.Args, "ssh")
|
||||
assert.Contains(t, cmd.Args, "exec")
|
||||
assert.Contains(t, cmd.Args, "--uid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--gid")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "--groups")
|
||||
assert.Contains(t, cmd.Args, "1000")
|
||||
assert.Contains(t, cmd.Args, "1001")
|
||||
assert.Contains(t, cmd.Args, "--working-dir")
|
||||
assert.Contains(t, cmd.Args, "/home/testuser")
|
||||
assert.Contains(t, cmd.Args, "--shell")
|
||||
assert.Contains(t, cmd.Args, "/bin/bash")
|
||||
assert.Contains(t, cmd.Args, "--cmd")
|
||||
assert.Contains(t, cmd.Args, "ls -la")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
config := ExecutorConfig{
|
||||
UID: 1000,
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/home/testuser",
|
||||
Shell: "/bin/bash",
|
||||
Command: "",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, cmd)
|
||||
|
||||
// Verify no command mode (command is empty so no --cmd flag)
|
||||
assert.NotContains(t, cmd.Args, "--cmd")
|
||||
assert.NotContains(t, cmd.Args, "--interactive")
|
||||
}
|
||||
|
||||
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
|
||||
// This test requires root privileges and will be skipped if not running as root
|
||||
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
|
||||
if os.Geteuid() != 0 {
|
||||
t.Skip("This test requires root privileges")
|
||||
}
|
||||
|
||||
// Find a non-root user to test with
|
||||
testUser, err := findNonRootUser()
|
||||
if err != nil {
|
||||
t.Skip("No suitable non-root user found for testing")
|
||||
}
|
||||
|
||||
// Verify the user actually exists by looking it up again
|
||||
_, err = user.LookupId(testUser.Uid)
|
||||
if err != nil {
|
||||
t.Skipf("Test user %s (UID %s) does not exist on this system: %v", testUser.Username, testUser.Uid, err)
|
||||
}
|
||||
|
||||
uid64, err := strconv.ParseUint(testUser.Uid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
targetUID := uint32(uid64)
|
||||
|
||||
gid64, err := strconv.ParseUint(testUser.Gid, 10, 32)
|
||||
require.NoError(t, err)
|
||||
targetGID := uint32(gid64)
|
||||
|
||||
// Test in a child process to avoid affecting the test runner
|
||||
if os.Getenv("TEST_PRIVILEGE_DROP") == "1" {
|
||||
pd := NewPrivilegeDropper()
|
||||
|
||||
// This should succeed
|
||||
err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify we are now running as the target user
|
||||
currentUID := uint32(os.Geteuid())
|
||||
currentGID := uint32(os.Getegid())
|
||||
|
||||
assert.Equal(t, targetUID, currentUID, "UID should match target")
|
||||
assert.Equal(t, targetGID, currentGID, "GID should match target")
|
||||
assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root")
|
||||
assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group")
|
||||
|
||||
return
|
||||
}
|
||||
|
||||
// Fork a child process to test privilege dropping
|
||||
cmd := os.Args[0]
|
||||
args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"}
|
||||
|
||||
env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1")
|
||||
|
||||
execCmd := exec.Command(cmd, args...)
|
||||
execCmd.Env = env
|
||||
|
||||
err = execCmd.Run()
|
||||
require.NoError(t, err, "Child process should succeed")
|
||||
}
|
||||
|
||||
// findNonRootUser finds any non-root user on the system for testing
|
||||
func findNonRootUser() (*user.User, error) {
|
||||
// Try common non-root users, but avoid "nobody" on macOS due to negative UID issues
|
||||
commonUsers := []string{"daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"}
|
||||
|
||||
for _, username := range commonUsers {
|
||||
if u, err := user.Lookup(username); err == nil {
|
||||
// Parse as signed integer first to handle negative UIDs
|
||||
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
// Skip negative UIDs (like nobody=-2 on macOS) and root
|
||||
if uid64 > 0 && uid64 != 0 {
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no common users found, try to find any regular user with UID > 100
|
||||
// This helps on macOS where regular users start at UID 501
|
||||
allUsers := []string{"vma", "user", "test", "admin"}
|
||||
for _, username := range allUsers {
|
||||
if u, err := user.Lookup(username); err == nil {
|
||||
uid64, err := strconv.ParseInt(u.Uid, 10, 32)
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
if uid64 > 100 { // Regular user
|
||||
return u, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// If no common users found, return an error
|
||||
return nil, fmt.Errorf("no suitable non-root user found on this system")
|
||||
}
|
||||
|
||||
func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) {
|
||||
pd := NewPrivilegeDropper()
|
||||
currentUID := uint32(os.Geteuid())
|
||||
|
||||
if currentUID == 0 {
|
||||
// When running as root, test that root can create commands for any user
|
||||
config := ExecutorConfig{
|
||||
UID: 1000, // Target non-root user
|
||||
GID: 1000,
|
||||
Groups: []uint32{1000},
|
||||
WorkingDir: "/tmp",
|
||||
Shell: "/bin/sh",
|
||||
Command: "echo test",
|
||||
}
|
||||
|
||||
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
assert.NoError(t, err, "Root should be able to create commands for any user")
|
||||
assert.NotNil(t, cmd)
|
||||
} else {
|
||||
// When running as non-root, test that we can't drop to a different user
|
||||
config := ExecutorConfig{
|
||||
UID: 0, // Try to target root
|
||||
GID: 0,
|
||||
Groups: []uint32{0},
|
||||
WorkingDir: "/tmp",
|
||||
Shell: "/bin/sh",
|
||||
Command: "echo test",
|
||||
}
|
||||
|
||||
_, err := pd.CreateExecutorCommand(context.Background(), config)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "cannot drop privileges")
|
||||
}
|
||||
}
|
||||
570
client/ssh/server/executor_windows.go
Normal file
570
client/ssh/server/executor_windows.go
Normal file
@@ -0,0 +1,570 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
const (
|
||||
ExitCodeSuccess = 0
|
||||
ExitCodeLogonFail = 10
|
||||
ExitCodeCreateProcessFail = 11
|
||||
ExitCodeWorkingDirFail = 12
|
||||
ExitCodeShellExecFail = 13
|
||||
ExitCodeValidationFail = 14
|
||||
)
|
||||
|
||||
type WindowsExecutorConfig struct {
|
||||
Username string
|
||||
Domain string
|
||||
WorkingDir string
|
||||
Shell string
|
||||
Command string
|
||||
Args []string
|
||||
Interactive bool
|
||||
Pty bool
|
||||
PtyWidth int
|
||||
PtyHeight int
|
||||
}
|
||||
|
||||
type PrivilegeDropper struct{}
|
||||
|
||||
func NewPrivilegeDropper() *PrivilegeDropper {
|
||||
return &PrivilegeDropper{}
|
||||
}
|
||||
|
||||
var (
|
||||
advapi32 = windows.NewLazyDLL("advapi32.dll")
|
||||
procAllocateLocallyUniqueId = advapi32.NewProc("AllocateLocallyUniqueId")
|
||||
)
|
||||
|
||||
const (
|
||||
logon32LogonNetwork = 3 // Network logon - no password required for authenticated users
|
||||
|
||||
// Common error messages
|
||||
commandFlag = "-Command"
|
||||
closeTokenErrorMsg = "close token error: %v" // #nosec G101 -- This is an error message template, not credentials
|
||||
convertUsernameError = "convert username to UTF16: %w"
|
||||
convertDomainError = "convert domain to UTF16: %w"
|
||||
)
|
||||
|
||||
// CreateWindowsExecutorCommand creates a Windows command with privilege dropping.
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (pd *PrivilegeDropper) CreateWindowsExecutorCommand(ctx context.Context, config WindowsExecutorConfig) (*exec.Cmd, windows.Token, error) {
|
||||
if config.Username == "" {
|
||||
return nil, 0, errors.New("username cannot be empty")
|
||||
}
|
||||
if config.Shell == "" {
|
||||
return nil, 0, errors.New("shell cannot be empty")
|
||||
}
|
||||
|
||||
shell := config.Shell
|
||||
|
||||
var shellArgs []string
|
||||
if config.Command != "" {
|
||||
shellArgs = []string{shell, commandFlag, config.Command}
|
||||
} else {
|
||||
shellArgs = []string{shell}
|
||||
}
|
||||
|
||||
log.Tracef("creating Windows direct shell command: %s %v", shellArgs[0], shellArgs)
|
||||
|
||||
cmd, token, err := pd.CreateWindowsProcessAsUser(
|
||||
ctx, shellArgs[0], shellArgs, config.Username, config.Domain, config.WorkingDir)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create Windows process as user: %w", err)
|
||||
}
|
||||
|
||||
return cmd, token, nil
|
||||
}
|
||||
|
||||
const (
|
||||
// StatusSuccess represents successful LSA operation
|
||||
StatusSuccess = 0
|
||||
|
||||
// KerbS4ULogonType message type for domain users with Kerberos
|
||||
KerbS4ULogonType = 12
|
||||
// Msv10s4ulogontype message type for local users with MSV1_0
|
||||
Msv10s4ulogontype = 12
|
||||
|
||||
// MicrosoftKerberosNameA is the authentication package name for Kerberos
|
||||
MicrosoftKerberosNameA = "Kerberos"
|
||||
// Msv10packagename is the authentication package name for MSV1_0
|
||||
Msv10packagename = "MICROSOFT_AUTHENTICATION_PACKAGE_V1_0"
|
||||
|
||||
NameSamCompatible = 2
|
||||
NameUserPrincipal = 8
|
||||
NameCanonical = 7
|
||||
|
||||
maxUPNLen = 1024
|
||||
)
|
||||
|
||||
// kerbS4ULogon structure for S4U authentication (domain users)
|
||||
type kerbS4ULogon struct {
|
||||
MessageType uint32
|
||||
Flags uint32
|
||||
ClientUpn unicodeString
|
||||
ClientRealm unicodeString
|
||||
}
|
||||
|
||||
// msv10s4ulogon structure for S4U authentication (local users)
|
||||
type msv10s4ulogon struct {
|
||||
MessageType uint32
|
||||
Flags uint32
|
||||
UserPrincipalName unicodeString
|
||||
DomainName unicodeString
|
||||
}
|
||||
|
||||
// unicodeString structure
|
||||
type unicodeString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer *uint16
|
||||
}
|
||||
|
||||
// lsaString structure
|
||||
type lsaString struct {
|
||||
Length uint16
|
||||
MaximumLength uint16
|
||||
Buffer *byte
|
||||
}
|
||||
|
||||
// tokenSource structure
|
||||
type tokenSource struct {
|
||||
SourceName [8]byte
|
||||
SourceIdentifier windows.LUID
|
||||
}
|
||||
|
||||
// quotaLimits structure
|
||||
type quotaLimits struct {
|
||||
PagedPoolLimit uint32
|
||||
NonPagedPoolLimit uint32
|
||||
MinimumWorkingSetSize uint32
|
||||
MaximumWorkingSetSize uint32
|
||||
PagefileLimit uint32
|
||||
TimeLimit int64
|
||||
}
|
||||
|
||||
var (
|
||||
secur32 = windows.NewLazyDLL("secur32.dll")
|
||||
procLsaRegisterLogonProcess = secur32.NewProc("LsaRegisterLogonProcess")
|
||||
procLsaLookupAuthenticationPackage = secur32.NewProc("LsaLookupAuthenticationPackage")
|
||||
procLsaLogonUser = secur32.NewProc("LsaLogonUser")
|
||||
procLsaFreeReturnBuffer = secur32.NewProc("LsaFreeReturnBuffer")
|
||||
procLsaDeregisterLogonProcess = secur32.NewProc("LsaDeregisterLogonProcess")
|
||||
procTranslateNameW = secur32.NewProc("TranslateNameW")
|
||||
)
|
||||
|
||||
// newLsaString creates an LsaString from a Go string
|
||||
func newLsaString(s string) lsaString {
|
||||
b := append([]byte(s), 0)
|
||||
return lsaString{
|
||||
Length: uint16(len(s)),
|
||||
MaximumLength: uint16(len(b)),
|
||||
Buffer: &b[0],
|
||||
}
|
||||
}
|
||||
|
||||
// generateS4UUserToken creates a Windows token using S4U authentication
|
||||
// This is the exact approach OpenSSH for Windows uses for public key authentication
|
||||
func generateS4UUserToken(username, domain string) (windows.Handle, error) {
|
||||
userCpn := buildUserCpn(username, domain)
|
||||
|
||||
pd := NewPrivilegeDropper()
|
||||
isDomainUser := !pd.isLocalUser(domain)
|
||||
|
||||
lsaHandle, err := initializeLsaConnection()
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
defer cleanupLsaConnection(lsaHandle)
|
||||
|
||||
authPackageId, err := lookupAuthenticationPackage(lsaHandle, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
logonInfo, logonInfoSize, err := prepareS4ULogonStructure(username, domain, isDomainUser)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return performS4ULogon(lsaHandle, authPackageId, logonInfo, logonInfoSize, userCpn, isDomainUser)
|
||||
}
|
||||
|
||||
// buildUserCpn constructs the user principal name
|
||||
func buildUserCpn(username, domain string) string {
|
||||
if domain != "" && domain != "." {
|
||||
return fmt.Sprintf(`%s\%s`, domain, username)
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
// initializeLsaConnection establishes connection to LSA
|
||||
func initializeLsaConnection() (windows.Handle, error) {
|
||||
|
||||
processName := newLsaString("NetBird")
|
||||
var mode uint32
|
||||
var lsaHandle windows.Handle
|
||||
ret, _, _ := procLsaRegisterLogonProcess.Call(
|
||||
uintptr(unsafe.Pointer(&processName)),
|
||||
uintptr(unsafe.Pointer(&lsaHandle)),
|
||||
uintptr(unsafe.Pointer(&mode)),
|
||||
)
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaRegisterLogonProcess: 0x%x", ret)
|
||||
}
|
||||
|
||||
return lsaHandle, nil
|
||||
}
|
||||
|
||||
// cleanupLsaConnection closes the LSA connection
|
||||
func cleanupLsaConnection(lsaHandle windows.Handle) {
|
||||
if ret, _, _ := procLsaDeregisterLogonProcess.Call(uintptr(lsaHandle)); ret != StatusSuccess {
|
||||
log.Debugf("LsaDeregisterLogonProcess failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
// lookupAuthenticationPackage finds the correct authentication package
|
||||
func lookupAuthenticationPackage(lsaHandle windows.Handle, isDomainUser bool) (uint32, error) {
|
||||
var authPackageName lsaString
|
||||
if isDomainUser {
|
||||
authPackageName = newLsaString(MicrosoftKerberosNameA)
|
||||
} else {
|
||||
authPackageName = newLsaString(Msv10packagename)
|
||||
}
|
||||
|
||||
var authPackageId uint32
|
||||
ret, _, _ := procLsaLookupAuthenticationPackage.Call(
|
||||
uintptr(lsaHandle),
|
||||
uintptr(unsafe.Pointer(&authPackageName)),
|
||||
uintptr(unsafe.Pointer(&authPackageId)),
|
||||
)
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaLookupAuthenticationPackage: 0x%x", ret)
|
||||
}
|
||||
|
||||
return authPackageId, nil
|
||||
}
|
||||
|
||||
// lookupPrincipalName converts DOMAIN\username to username@domain.fqdn (UPN format)
|
||||
func lookupPrincipalName(username, domain string) (string, error) {
|
||||
samAccountName := fmt.Sprintf(`%s\%s`, domain, username)
|
||||
samAccountNameUtf16, err := windows.UTF16PtrFromString(samAccountName)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("convert SAM account name to UTF-16: %w", err)
|
||||
}
|
||||
|
||||
upnBuf := make([]uint16, maxUPNLen+1)
|
||||
upnSize := uint32(len(upnBuf))
|
||||
|
||||
ret, _, _ := procTranslateNameW.Call(
|
||||
uintptr(unsafe.Pointer(samAccountNameUtf16)),
|
||||
uintptr(NameSamCompatible),
|
||||
uintptr(NameUserPrincipal),
|
||||
uintptr(unsafe.Pointer(&upnBuf[0])),
|
||||
uintptr(unsafe.Pointer(&upnSize)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
upn := windows.UTF16ToString(upnBuf[:upnSize])
|
||||
log.Debugf("Translated %s to explicit UPN: %s", samAccountName, upn)
|
||||
return upn, nil
|
||||
}
|
||||
|
||||
upnSize = uint32(len(upnBuf))
|
||||
ret, _, _ = procTranslateNameW.Call(
|
||||
uintptr(unsafe.Pointer(samAccountNameUtf16)),
|
||||
uintptr(NameSamCompatible),
|
||||
uintptr(NameCanonical),
|
||||
uintptr(unsafe.Pointer(&upnBuf[0])),
|
||||
uintptr(unsafe.Pointer(&upnSize)),
|
||||
)
|
||||
|
||||
if ret != 0 {
|
||||
canonical := windows.UTF16ToString(upnBuf[:upnSize])
|
||||
slashIdx := strings.IndexByte(canonical, '/')
|
||||
if slashIdx > 0 {
|
||||
fqdn := canonical[:slashIdx]
|
||||
upn := fmt.Sprintf("%s@%s", username, fqdn)
|
||||
log.Debugf("Translated %s to implicit UPN: %s (from canonical: %s)", samAccountName, upn, canonical)
|
||||
return upn, nil
|
||||
}
|
||||
}
|
||||
|
||||
log.Debugf("Could not translate %s to UPN, using SAM format", samAccountName)
|
||||
return samAccountName, nil
|
||||
}
|
||||
|
||||
// prepareS4ULogonStructure creates the appropriate S4U logon structure
|
||||
func prepareS4ULogonStructure(username, domain string, isDomainUser bool) (unsafe.Pointer, uintptr, error) {
|
||||
if isDomainUser {
|
||||
return prepareDomainS4ULogon(username, domain)
|
||||
}
|
||||
return prepareLocalS4ULogon(username)
|
||||
}
|
||||
|
||||
// prepareDomainS4ULogon creates S4U logon structure for domain users
|
||||
func prepareDomainS4ULogon(username, domain string) (unsafe.Pointer, uintptr, error) {
|
||||
upn, err := lookupPrincipalName(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("lookup principal name: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("using KerbS4ULogon for domain user with UPN: %s", upn)
|
||||
|
||||
upnUtf16, err := windows.UTF16FromString(upn)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertUsernameError, err)
|
||||
}
|
||||
|
||||
structSize := unsafe.Sizeof(kerbS4ULogon{})
|
||||
upnByteSize := len(upnUtf16) * 2
|
||||
logonInfoSize := structSize + uintptr(upnByteSize)
|
||||
|
||||
buffer := make([]byte, logonInfoSize)
|
||||
logonInfo := unsafe.Pointer(&buffer[0])
|
||||
|
||||
s4uLogon := (*kerbS4ULogon)(logonInfo)
|
||||
s4uLogon.MessageType = KerbS4ULogonType
|
||||
s4uLogon.Flags = 0
|
||||
|
||||
upnOffset := structSize
|
||||
upnBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + upnOffset))
|
||||
copy((*[1025]uint16)(unsafe.Pointer(upnBuffer))[:len(upnUtf16)], upnUtf16)
|
||||
|
||||
s4uLogon.ClientUpn = unicodeString{
|
||||
Length: uint16((len(upnUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(upnUtf16) * 2),
|
||||
Buffer: upnBuffer,
|
||||
}
|
||||
s4uLogon.ClientRealm = unicodeString{}
|
||||
|
||||
return logonInfo, logonInfoSize, nil
|
||||
}
|
||||
|
||||
// prepareLocalS4ULogon creates S4U logon structure for local users
|
||||
func prepareLocalS4ULogon(username string) (unsafe.Pointer, uintptr, error) {
|
||||
log.Debugf("using Msv1_0S4ULogon for local user: %s", username)
|
||||
|
||||
usernameUtf16, err := windows.UTF16FromString(username)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertUsernameError, err)
|
||||
}
|
||||
|
||||
domainUtf16, err := windows.UTF16FromString(".")
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf(convertDomainError, err)
|
||||
}
|
||||
|
||||
structSize := unsafe.Sizeof(msv10s4ulogon{})
|
||||
usernameByteSize := len(usernameUtf16) * 2
|
||||
domainByteSize := len(domainUtf16) * 2
|
||||
logonInfoSize := structSize + uintptr(usernameByteSize) + uintptr(domainByteSize)
|
||||
|
||||
buffer := make([]byte, logonInfoSize)
|
||||
logonInfo := unsafe.Pointer(&buffer[0])
|
||||
|
||||
s4uLogon := (*msv10s4ulogon)(logonInfo)
|
||||
s4uLogon.MessageType = Msv10s4ulogontype
|
||||
s4uLogon.Flags = 0x0
|
||||
|
||||
usernameOffset := structSize
|
||||
usernameBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + usernameOffset))
|
||||
copy((*[256]uint16)(unsafe.Pointer(usernameBuffer))[:len(usernameUtf16)], usernameUtf16)
|
||||
|
||||
s4uLogon.UserPrincipalName = unicodeString{
|
||||
Length: uint16((len(usernameUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(usernameUtf16) * 2),
|
||||
Buffer: usernameBuffer,
|
||||
}
|
||||
|
||||
domainOffset := usernameOffset + uintptr(usernameByteSize)
|
||||
domainBuffer := (*uint16)(unsafe.Pointer(uintptr(logonInfo) + domainOffset))
|
||||
copy((*[16]uint16)(unsafe.Pointer(domainBuffer))[:len(domainUtf16)], domainUtf16)
|
||||
|
||||
s4uLogon.DomainName = unicodeString{
|
||||
Length: uint16((len(domainUtf16) - 1) * 2),
|
||||
MaximumLength: uint16(len(domainUtf16) * 2),
|
||||
Buffer: domainBuffer,
|
||||
}
|
||||
|
||||
return logonInfo, logonInfoSize, nil
|
||||
}
|
||||
|
||||
// performS4ULogon executes the S4U logon operation
|
||||
func performS4ULogon(lsaHandle windows.Handle, authPackageId uint32, logonInfo unsafe.Pointer, logonInfoSize uintptr, userCpn string, isDomainUser bool) (windows.Handle, error) {
|
||||
var tokenSource tokenSource
|
||||
copy(tokenSource.SourceName[:], "netbird")
|
||||
if ret, _, _ := procAllocateLocallyUniqueId.Call(uintptr(unsafe.Pointer(&tokenSource.SourceIdentifier))); ret == 0 {
|
||||
log.Debugf("AllocateLocallyUniqueId failed")
|
||||
}
|
||||
|
||||
originName := newLsaString("netbird")
|
||||
|
||||
var profile uintptr
|
||||
var profileSize uint32
|
||||
var logonId windows.LUID
|
||||
var token windows.Handle
|
||||
var quotas quotaLimits
|
||||
var subStatus int32
|
||||
|
||||
ret, _, _ := procLsaLogonUser.Call(
|
||||
uintptr(lsaHandle),
|
||||
uintptr(unsafe.Pointer(&originName)),
|
||||
logon32LogonNetwork,
|
||||
uintptr(authPackageId),
|
||||
uintptr(logonInfo),
|
||||
logonInfoSize,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&tokenSource)),
|
||||
uintptr(unsafe.Pointer(&profile)),
|
||||
uintptr(unsafe.Pointer(&profileSize)),
|
||||
uintptr(unsafe.Pointer(&logonId)),
|
||||
uintptr(unsafe.Pointer(&token)),
|
||||
uintptr(unsafe.Pointer("as)),
|
||||
uintptr(unsafe.Pointer(&subStatus)),
|
||||
)
|
||||
|
||||
if profile != 0 {
|
||||
if ret, _, _ := procLsaFreeReturnBuffer.Call(profile); ret != StatusSuccess {
|
||||
log.Debugf("LsaFreeReturnBuffer failed: 0x%x", ret)
|
||||
}
|
||||
}
|
||||
|
||||
if ret != StatusSuccess {
|
||||
return 0, fmt.Errorf("LsaLogonUser S4U for %s: NTSTATUS=0x%x, SubStatus=0x%x", userCpn, ret, subStatus)
|
||||
}
|
||||
|
||||
log.Debugf("created S4U %s token for user %s",
|
||||
map[bool]string{true: "domain", false: "local"}[isDomainUser], userCpn)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// createToken implements NetBird trust-based authentication using S4U
|
||||
func (pd *PrivilegeDropper) createToken(username, domain string) (windows.Handle, error) {
|
||||
fullUsername := buildUserCpn(username, domain)
|
||||
|
||||
if err := userExists(fullUsername, username, domain); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
isLocalUser := pd.isLocalUser(domain)
|
||||
|
||||
if isLocalUser {
|
||||
return pd.authenticateLocalUser(username, fullUsername)
|
||||
}
|
||||
return pd.authenticateDomainUser(username, domain, fullUsername)
|
||||
}
|
||||
|
||||
// userExists checks if the target useVerifier exists on the system
|
||||
func userExists(fullUsername, username, domain string) error {
|
||||
if _, err := lookupUser(fullUsername); err != nil {
|
||||
log.Debugf("User %s not found: %v", fullUsername, err)
|
||||
if domain != "" && domain != "." {
|
||||
_, err = lookupUser(username)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("target user %s not found: %w", fullUsername, err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// isLocalUser determines if this is a local user vs domain user
|
||||
func (pd *PrivilegeDropper) isLocalUser(domain string) bool {
|
||||
hostname, err := os.Hostname()
|
||||
if err != nil {
|
||||
hostname = "localhost"
|
||||
}
|
||||
|
||||
return domain == "" || domain == "." ||
|
||||
strings.EqualFold(domain, hostname)
|
||||
}
|
||||
|
||||
// authenticateLocalUser handles authentication for local users
|
||||
func (pd *PrivilegeDropper) authenticateLocalUser(username, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for local user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, ".")
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for local user %s: %w", fullUsername, err)
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// authenticateDomainUser handles authentication for domain users
|
||||
func (pd *PrivilegeDropper) authenticateDomainUser(username, domain, fullUsername string) (windows.Handle, error) {
|
||||
log.Debugf("using S4U authentication for domain user %s", fullUsername)
|
||||
token, err := generateS4UUserToken(username, domain)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("S4U authentication for domain user %s: %w", fullUsername, err)
|
||||
}
|
||||
log.Debugf("Successfully created S4U token for domain user %s", fullUsername)
|
||||
return token, nil
|
||||
}
|
||||
|
||||
// CreateWindowsProcessAsUser creates a process as user with safe argument passing (for SFTP and executables).
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (pd *PrivilegeDropper) CreateWindowsProcessAsUser(ctx context.Context, executablePath string, args []string, username, domain, workingDir string) (*exec.Cmd, windows.Token, error) {
|
||||
token, err := pd.createToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("user authentication: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Debugf("close impersonation token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmd, primaryToken, err := pd.createProcessWithToken(ctx, windows.Token(token), executablePath, args, workingDir)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// createProcessWithToken creates process with the specified token and executable path.
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (pd *PrivilegeDropper) createProcessWithToken(ctx context.Context, sourceToken windows.Token, executablePath string, args []string, workingDir string) (*exec.Cmd, windows.Token, error) {
|
||||
cmd := exec.CommandContext(ctx, executablePath, args[1:]...)
|
||||
cmd.Dir = workingDir
|
||||
|
||||
var primaryToken windows.Token
|
||||
err := windows.DuplicateTokenEx(
|
||||
sourceToken,
|
||||
windows.TOKEN_ALL_ACCESS,
|
||||
nil,
|
||||
windows.SecurityIdentification,
|
||||
windows.TokenPrimary,
|
||||
&primaryToken,
|
||||
)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("duplicate token to primary token: %w", err)
|
||||
}
|
||||
|
||||
cmd.SysProcAttr = &syscall.SysProcAttr{
|
||||
Token: syscall.Token(primaryToken),
|
||||
}
|
||||
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// createSuCommand creates a command using su -l -c for privilege switching (Windows stub)
|
||||
func (s *Server) createSuCommand(ssh.Session, *user.User, bool) (*exec.Cmd, error) {
|
||||
return nil, fmt.Errorf("su command not available on Windows")
|
||||
}
|
||||
629
client/ssh/server/jwt_test.go
Normal file
629
client/ssh/server/jwt_test.go
Normal file
@@ -0,0 +1,629 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT enforcement tests in short mode")
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
if err != nil {
|
||||
t.Logf("Detection failed: %v", err)
|
||||
}
|
||||
t.Logf("Detected server type: %s", serverType)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
|
||||
})
|
||||
|
||||
t.Run("allows_when_disabled", func(t *testing.T) {
|
||||
serverConfigNoJWT := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
serverNoJWT := New(serverConfigNoJWT)
|
||||
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||
assert.False(t, serverType.RequiresJWT())
|
||||
|
||||
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64RawURLEncode(n),
|
||||
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func base64RawURLEncode(data []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// generateValidJWT creates a valid JWT token for testing
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
|
||||
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
|
||||
t.Helper()
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
ctx := context.Background()
|
||||
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
|
||||
func TestJWTDetection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT detection test in short mode")
|
||||
}
|
||||
|
||||
jwksServer, _, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||
assert.True(t, serverType.RequiresJWT())
|
||||
}
|
||||
|
||||
func TestJWTFailClose(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT fail-close tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenClaims jwt.MapClaims
|
||||
}{
|
||||
{
|
||||
name: "blocks_token_missing_iat",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_sub",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_iss",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_aud",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_issuer",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": "wrong-issuer",
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_audience",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": "wrong-audience",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_expired_token",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_exceeding_max_age",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
MaxTokenAge: 3600,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(tokenString),
|
||||
},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
assert.Error(t, err, "Authentication should fail (fail-close)")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
|
||||
func TestJWTAuthentication(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT authentication tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantAuthOK bool
|
||||
setupServer func(*Server)
|
||||
testOperation func(*testing.T, *cryptossh.Client, string) error
|
||||
wantOpSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "allows_shell_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
return session.Shell()
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "rejects_invalid_token",
|
||||
token: "invalid",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_shell_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_command_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("ls")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_sftp_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
return session.RequestSubsystem("sftp")
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_sftp_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
err = session.RequestSubsystem("sftp")
|
||||
if err == nil {
|
||||
err = session.Wait()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_port_forward_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowRemotePortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_port_forward_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowLocalPortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// TODO: Skip port forwarding tests on Windows - user switching not supported
|
||||
// These features are tested on Linux/Unix platforms
|
||||
if runtime.GOOS == "windows" &&
|
||||
(tc.name == "allows_port_forward_with_jwt" ||
|
||||
tc.name == "blocks_port_forward_without_jwt") {
|
||||
t.Skip("Skipping port forwarding test on Windows - covered by Linux tests")
|
||||
}
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
if tc.setupServer != nil {
|
||||
tc.setupServer(server)
|
||||
}
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
var authMethods []cryptossh.AuthMethod
|
||||
if tc.token == "valid" {
|
||||
token := generateValidJWT(t, privateKey, issuer, audience)
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
}
|
||||
} else if tc.token == "invalid" {
|
||||
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(invalidToken),
|
||||
}
|
||||
}
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if tc.wantAuthOK {
|
||||
require.NoError(t, err, "JWT authentication should succeed")
|
||||
} else if err != nil {
|
||||
t.Logf("Connection failed as expected: %v", err)
|
||||
return
|
||||
}
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err = tc.testOperation(t, conn, serverAddr)
|
||||
if tc.wantOpSuccess {
|
||||
require.NoError(t, err, "Operation should succeed")
|
||||
} else {
|
||||
assert.Error(t, err, "Operation should fail")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
386
client/ssh/server/port_forwarding.go
Normal file
386
client/ssh/server/port_forwarding.go
Normal file
@@ -0,0 +1,386 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// SessionKey uniquely identifies an SSH session
|
||||
type SessionKey string
|
||||
|
||||
// ConnectionKey uniquely identifies a port forwarding connection within a session
|
||||
type ConnectionKey string
|
||||
|
||||
// ForwardKey uniquely identifies a port forwarding listener
|
||||
type ForwardKey string
|
||||
|
||||
// tcpipForwardMsg represents the structure for tcpip-forward SSH requests
|
||||
type tcpipForwardMsg struct {
|
||||
Host string
|
||||
Port uint32
|
||||
}
|
||||
|
||||
// SetAllowLocalPortForwarding configures local port forwarding
|
||||
func (s *Server) SetAllowLocalPortForwarding(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowLocalPortForwarding = allow
|
||||
}
|
||||
|
||||
// SetAllowRemotePortForwarding configures remote port forwarding
|
||||
func (s *Server) SetAllowRemotePortForwarding(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowRemotePortForwarding = allow
|
||||
}
|
||||
|
||||
// configurePortForwarding sets up port forwarding callbacks
|
||||
func (s *Server) configurePortForwarding(server *ssh.Server) {
|
||||
allowLocal := s.allowLocalPortForwarding
|
||||
allowRemote := s.allowRemotePortForwarding
|
||||
|
||||
server.LocalPortForwardingCallback = func(ctx ssh.Context, dstHost string, dstPort uint32) bool {
|
||||
if !allowLocal {
|
||||
log.Warnf("local port forwarding denied for %s from %s: disabled by configuration",
|
||||
net.JoinHostPort(dstHost, fmt.Sprintf("%d", dstPort)), ctx.RemoteAddr())
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", dstPort); err != nil {
|
||||
log.Warnf("local port forwarding denied for %s:%d from %s: %v", dstHost, dstPort, ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("local port forwarding allowed: %s:%d", dstHost, dstPort)
|
||||
return true
|
||||
}
|
||||
|
||||
server.ReversePortForwardingCallback = func(ctx ssh.Context, bindHost string, bindPort uint32) bool {
|
||||
if !allowRemote {
|
||||
log.Warnf("remote port forwarding denied for %s from %s: disabled by configuration",
|
||||
net.JoinHostPort(bindHost, fmt.Sprintf("%d", bindPort)), ctx.RemoteAddr())
|
||||
return false
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "remote", bindPort); err != nil {
|
||||
log.Warnf("remote port forwarding denied for %s:%d from %s: %v", bindHost, bindPort, ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Debugf("remote port forwarding allowed: %s:%d", bindHost, bindPort)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("SSH server configured with local_forwarding=%v, remote_forwarding=%v", allowLocal, allowRemote)
|
||||
}
|
||||
|
||||
// checkPortForwardingPrivileges validates privilege requirements for port forwarding operations.
|
||||
// Returns nil if allowed, error if denied.
|
||||
func (s *Server) checkPortForwardingPrivileges(ctx ssh.Context, forwardType string, port uint32) error {
|
||||
if ctx == nil {
|
||||
return fmt.Errorf("%s port forwarding denied: no context", forwardType)
|
||||
}
|
||||
|
||||
username := ctx.User()
|
||||
remoteAddr := "unknown"
|
||||
if ctx.RemoteAddr() != nil {
|
||||
remoteAddr = ctx.RemoteAddr().String()
|
||||
}
|
||||
|
||||
logger := log.WithFields(log.Fields{"user": username, "remote": remoteAddr, "port": port})
|
||||
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: false,
|
||||
FeatureName: forwardType + " port forwarding",
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return result.Error
|
||||
}
|
||||
|
||||
logger.Debugf("%s port forwarding allowed: user %s validated (port %d)",
|
||||
forwardType, result.User.Username, port)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// tcpipForwardHandler handles tcpip-forward requests for remote port forwarding.
|
||||
func (s *Server) tcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
if !s.isRemotePortForwardingAllowed() {
|
||||
logger.Warnf("tcpip-forward request denied: remote port forwarding disabled")
|
||||
return false, nil
|
||||
}
|
||||
|
||||
payload, err := s.parseTcpipForwardRequest(req)
|
||||
if err != nil {
|
||||
logger.Errorf("tcpip-forward unmarshal error: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "tcpip-forward", payload.Port); err != nil {
|
||||
logger.Warnf("tcpip-forward denied: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
logger.Debugf("tcpip-forward request: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
sshConn, err := s.getSSHConnection(ctx)
|
||||
if err != nil {
|
||||
logger.Warnf("tcpip-forward request denied: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
return s.setupDirectForward(ctx, logger, sshConn, payload)
|
||||
}
|
||||
|
||||
// cancelTcpipForwardHandler handles cancel-tcpip-forward requests.
|
||||
func (s *Server) cancelTcpipForwardHandler(ctx ssh.Context, _ *ssh.Server, req *cryptossh.Request) (bool, []byte) {
|
||||
logger := s.getRequestLogger(ctx)
|
||||
|
||||
var payload tcpipForwardMsg
|
||||
if err := cryptossh.Unmarshal(req.Payload, &payload); err != nil {
|
||||
logger.Errorf("cancel-tcpip-forward unmarshal error: %v", err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
if s.removeRemoteForwardListener(key) {
|
||||
logger.Infof("remote port forwarding cancelled: %s:%d", payload.Host, payload.Port)
|
||||
return true, nil
|
||||
}
|
||||
|
||||
logger.Warnf("cancel-tcpip-forward failed: no listener found for %s:%d", payload.Host, payload.Port)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// handleRemoteForwardListener handles incoming connections for remote port forwarding.
|
||||
func (s *Server) handleRemoteForwardListener(ctx ssh.Context, ln net.Listener, host string, port uint32) {
|
||||
log.Debugf("starting remote forward listener handler for %s:%d", host, port)
|
||||
|
||||
defer func() {
|
||||
log.Debugf("cleaning up remote forward listener for %s:%d", host, port)
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("remote forward listener close error: %v", err)
|
||||
} else {
|
||||
log.Debugf("remote forward listener closed successfully for %s:%d", host, port)
|
||||
}
|
||||
}()
|
||||
|
||||
acceptChan := make(chan acceptResult, 1)
|
||||
|
||||
go func() {
|
||||
for {
|
||||
conn, err := ln.Accept()
|
||||
select {
|
||||
case acceptChan <- acceptResult{conn: conn, err: err}:
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}()
|
||||
|
||||
for {
|
||||
select {
|
||||
case result := <-acceptChan:
|
||||
if result.err != nil {
|
||||
log.Debugf("remote forward accept error: %v", result.err)
|
||||
return
|
||||
}
|
||||
go s.handleRemoteForwardConnection(ctx, result.conn, host, port)
|
||||
case <-ctx.Done():
|
||||
log.Debugf("remote forward listener shutting down due to context cancellation for %s:%d", host, port)
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// getRequestLogger creates a logger with user and remote address context
|
||||
func (s *Server) getRequestLogger(ctx ssh.Context) *log.Entry {
|
||||
remoteAddr := "unknown"
|
||||
username := "unknown"
|
||||
if ctx != nil {
|
||||
if ctx.RemoteAddr() != nil {
|
||||
remoteAddr = ctx.RemoteAddr().String()
|
||||
}
|
||||
username = ctx.User()
|
||||
}
|
||||
return log.WithFields(log.Fields{"user": username, "remote": remoteAddr})
|
||||
}
|
||||
|
||||
// isRemotePortForwardingAllowed checks if remote port forwarding is enabled
|
||||
func (s *Server) isRemotePortForwardingAllowed() bool {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
return s.allowRemotePortForwarding
|
||||
}
|
||||
|
||||
// parseTcpipForwardRequest parses the SSH request payload
|
||||
func (s *Server) parseTcpipForwardRequest(req *cryptossh.Request) (*tcpipForwardMsg, error) {
|
||||
var payload tcpipForwardMsg
|
||||
err := cryptossh.Unmarshal(req.Payload, &payload)
|
||||
return &payload, err
|
||||
}
|
||||
|
||||
// getSSHConnection extracts SSH connection from context
|
||||
func (s *Server) getSSHConnection(ctx ssh.Context) (*cryptossh.ServerConn, error) {
|
||||
if ctx == nil {
|
||||
return nil, fmt.Errorf("no context")
|
||||
}
|
||||
sshConnValue := ctx.Value(ssh.ContextKeyConn)
|
||||
if sshConnValue == nil {
|
||||
return nil, fmt.Errorf("no SSH connection in context")
|
||||
}
|
||||
sshConn, ok := sshConnValue.(*cryptossh.ServerConn)
|
||||
if !ok || sshConn == nil {
|
||||
return nil, fmt.Errorf("invalid SSH connection in context")
|
||||
}
|
||||
return sshConn, nil
|
||||
}
|
||||
|
||||
// setupDirectForward sets up a direct port forward
|
||||
func (s *Server) setupDirectForward(ctx ssh.Context, logger *log.Entry, sshConn *cryptossh.ServerConn, payload *tcpipForwardMsg) (bool, []byte) {
|
||||
bindAddr := net.JoinHostPort(payload.Host, strconv.FormatUint(uint64(payload.Port), 10))
|
||||
|
||||
ln, err := net.Listen("tcp", bindAddr)
|
||||
if err != nil {
|
||||
logger.Errorf("tcpip-forward listen failed on %s: %v", bindAddr, err)
|
||||
return false, nil
|
||||
}
|
||||
|
||||
actualPort := payload.Port
|
||||
if payload.Port == 0 {
|
||||
tcpAddr := ln.Addr().(*net.TCPAddr)
|
||||
actualPort = uint32(tcpAddr.Port)
|
||||
logger.Debugf("tcpip-forward allocated port %d for %s", actualPort, payload.Host)
|
||||
}
|
||||
|
||||
key := ForwardKey(fmt.Sprintf("%s:%d", payload.Host, payload.Port))
|
||||
s.storeRemoteForwardListener(key, ln)
|
||||
|
||||
s.markConnectionActivePortForward(sshConn, ctx.User(), ctx.RemoteAddr().String())
|
||||
go s.handleRemoteForwardListener(ctx, ln, payload.Host, actualPort)
|
||||
|
||||
response := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(response, actualPort)
|
||||
|
||||
logger.Infof("remote port forwarding established: %s:%d", payload.Host, actualPort)
|
||||
return true, response
|
||||
}
|
||||
|
||||
// acceptResult holds the result of a listener Accept() call
|
||||
type acceptResult struct {
|
||||
conn net.Conn
|
||||
err error
|
||||
}
|
||||
|
||||
// handleRemoteForwardConnection handles a single remote port forwarding connection
|
||||
func (s *Server) handleRemoteForwardConnection(ctx ssh.Context, conn net.Conn, host string, port uint32) {
|
||||
sessionKey := s.findSessionKeyByContext(ctx)
|
||||
connID := fmt.Sprintf("pf-%s->%s:%d", conn.RemoteAddr(), host, port)
|
||||
logger := log.WithFields(log.Fields{
|
||||
"session": sessionKey,
|
||||
"conn": connID,
|
||||
})
|
||||
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sshConn := ctx.Value(ssh.ContextKeyConn).(*cryptossh.ServerConn)
|
||||
if sshConn == nil {
|
||||
logger.Debugf("remote forward: no SSH connection in context")
|
||||
return
|
||||
}
|
||||
|
||||
remoteAddr, ok := conn.RemoteAddr().(*net.TCPAddr)
|
||||
if !ok {
|
||||
logger.Warnf("remote forward: non-TCP connection type: %T", conn.RemoteAddr())
|
||||
return
|
||||
}
|
||||
|
||||
channel, err := s.openForwardChannel(sshConn, host, port, remoteAddr, logger)
|
||||
if err != nil {
|
||||
logger.Debugf("open forward channel: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
s.proxyForwardConnection(ctx, logger, conn, channel)
|
||||
}
|
||||
|
||||
// openForwardChannel creates an SSH forwarded-tcpip channel
|
||||
func (s *Server) openForwardChannel(sshConn *cryptossh.ServerConn, host string, port uint32, remoteAddr *net.TCPAddr, logger *log.Entry) (cryptossh.Channel, error) {
|
||||
logger.Tracef("opening forwarded-tcpip channel for %s:%d", host, port)
|
||||
|
||||
payload := struct {
|
||||
ConnectedAddress string
|
||||
ConnectedPort uint32
|
||||
OriginatorAddress string
|
||||
OriginatorPort uint32
|
||||
}{
|
||||
ConnectedAddress: host,
|
||||
ConnectedPort: port,
|
||||
OriginatorAddress: remoteAddr.IP.String(),
|
||||
OriginatorPort: uint32(remoteAddr.Port),
|
||||
}
|
||||
|
||||
channel, reqs, err := sshConn.OpenChannel("forwarded-tcpip", cryptossh.Marshal(&payload))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("open SSH channel: %w", err)
|
||||
}
|
||||
|
||||
go cryptossh.DiscardRequests(reqs)
|
||||
return channel, nil
|
||||
}
|
||||
|
||||
// proxyForwardConnection handles bidirectional data transfer between connection and SSH channel
|
||||
func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn net.Conn, channel cryptossh.Channel) {
|
||||
done := make(chan struct{}, 2)
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(channel, conn); err != nil {
|
||||
logger.Debugf("copy error (conn->channel): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
go func() {
|
||||
if _, err := io.Copy(conn, channel); err != nil {
|
||||
logger.Debugf("copy error (channel->conn): %v", err)
|
||||
}
|
||||
done <- struct{}{}
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
// First copy finished, wait for second copy or context cancellation
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
logger.Debugf("session ended, closing connections")
|
||||
case <-done:
|
||||
}
|
||||
}
|
||||
|
||||
if err := channel.Close(); err != nil {
|
||||
logger.Debugf("channel close error: %v", err)
|
||||
}
|
||||
if err := conn.Close(); err != nil {
|
||||
logger.Debugf("connection close error: %v", err)
|
||||
}
|
||||
}
|
||||
712
client/ssh/server/server.go
Normal file
712
client/ssh/server/server.go
Normal file
@@ -0,0 +1,712 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"golang.org/x/exp/maps"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/shared/auth"
|
||||
"github.com/netbirdio/netbird/shared/auth/jwt"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
const DefaultSSHPort = 22
|
||||
|
||||
// InternalSSHPort is the port SSH server listens on and is redirected to
|
||||
const InternalSSHPort = 22022
|
||||
|
||||
const (
|
||||
errWriteSession = "write session error: %v"
|
||||
errExitSession = "exit session error: %v"
|
||||
|
||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||
|
||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||
DefaultJWTMaxTokenAge = 5 * 60
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPrivilegedUserDisabled = errors.New(msgPrivilegedUserDisabled)
|
||||
ErrUserNotFound = errors.New("user not found")
|
||||
)
|
||||
|
||||
// PrivilegedUserError represents an error when privileged user login is disabled
|
||||
type PrivilegedUserError struct {
|
||||
Username string
|
||||
}
|
||||
|
||||
func (e *PrivilegedUserError) Error() string {
|
||||
return fmt.Sprintf("%s for user: %s", msgPrivilegedUserDisabled, e.Username)
|
||||
}
|
||||
|
||||
func (e *PrivilegedUserError) Is(target error) bool {
|
||||
return target == ErrPrivilegedUserDisabled
|
||||
}
|
||||
|
||||
// UserNotFoundError represents an error when a user cannot be found
|
||||
type UserNotFoundError struct {
|
||||
Username string
|
||||
Cause error
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Error() string {
|
||||
if e.Cause != nil {
|
||||
return fmt.Sprintf("user %s not found: %v", e.Username, e.Cause)
|
||||
}
|
||||
return fmt.Sprintf("user %s not found", e.Username)
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Is(target error) bool {
|
||||
return target == ErrUserNotFound
|
||||
}
|
||||
|
||||
func (e *UserNotFoundError) Unwrap() error {
|
||||
return e.Cause
|
||||
}
|
||||
|
||||
// logSessionExitError logs session exit errors, ignoring EOF (normal close) errors
|
||||
func logSessionExitError(logger *log.Entry, err error) {
|
||||
if err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Warnf(errExitSession, err)
|
||||
}
|
||||
}
|
||||
|
||||
// safeLogCommand returns a safe representation of the command for logging
|
||||
func safeLogCommand(cmd []string) string {
|
||||
if len(cmd) == 0 {
|
||||
return "<interactive shell>"
|
||||
}
|
||||
if len(cmd) == 1 {
|
||||
return cmd[0]
|
||||
}
|
||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||
}
|
||||
|
||||
type sshConnectionState struct {
|
||||
hasActivePortForward bool
|
||||
username string
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
type authKey string
|
||||
|
||||
func newAuthKey(username string, remoteAddr net.Addr) authKey {
|
||||
return authKey(fmt.Sprintf("%s@%s", username, remoteAddr.String()))
|
||||
}
|
||||
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
sessions map[SessionKey]ssh.Session
|
||||
sessionCancels map[ConnectionKey]context.CancelFunc
|
||||
sessionJWTUsers map[SessionKey]string
|
||||
pendingAuthJWT map[authKey]string
|
||||
|
||||
allowLocalPortForwarding bool
|
||||
allowRemotePortForwarding bool
|
||||
allowRootLogin bool
|
||||
allowSFTP bool
|
||||
jwtEnabled bool
|
||||
|
||||
netstackNet *netstack.Net
|
||||
|
||||
wgAddress wgaddr.Address
|
||||
|
||||
remoteForwardListeners map[ForwardKey]net.Listener
|
||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
||||
|
||||
jwtValidator *jwt.Validator
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
|
||||
suSupportsPty bool
|
||||
}
|
||||
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
}
|
||||
|
||||
// Config contains all SSH server configuration options
|
||||
type Config struct {
|
||||
// JWT authentication configuration. If nil, JWT authentication is disabled
|
||||
JWT *JWTConfig
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
}
|
||||
|
||||
// SessionInfo contains information about an active SSH session
|
||||
type SessionInfo struct {
|
||||
Username string
|
||||
RemoteAddress string
|
||||
Command string
|
||||
JWTUsername string
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
sessions: make(map[SessionKey]ssh.Session),
|
||||
sessionJWTUsers: make(map[SessionKey]string),
|
||||
pendingAuthJWT: make(map[authKey]string),
|
||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start runs the SSH server
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer != nil {
|
||||
return errors.New("SSH server is already running")
|
||||
}
|
||||
|
||||
s.suSupportsPty = s.detectSuPtySupport(ctx)
|
||||
|
||||
ln, addrDesc, err := s.createListener(ctx, addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create listener: %w", err)
|
||||
}
|
||||
|
||||
sshServer, err := s.createSSHServer(ln.Addr())
|
||||
if err != nil {
|
||||
s.closeListener(ln)
|
||||
return fmt.Errorf("create SSH server: %w", err)
|
||||
}
|
||||
|
||||
s.sshServer = sshServer
|
||||
log.Infof("SSH server started on %s", addrDesc)
|
||||
|
||||
go func() {
|
||||
if err := sshServer.Serve(ln); err != nil && !errors.Is(err, ssh.ErrServerClosed) {
|
||||
log.Errorf("SSH server error: %v", err)
|
||||
}
|
||||
}()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("listen on netstack: %w", err)
|
||||
}
|
||||
return ln, fmt.Sprintf("netstack %s", addr), nil
|
||||
}
|
||||
|
||||
tcpAddr := net.TCPAddrFromAddrPort(addr)
|
||||
lc := net.ListenConfig{}
|
||||
ln, err := lc.Listen(ctx, "tcp", tcpAddr.String())
|
||||
if err != nil {
|
||||
return nil, "", fmt.Errorf("listen: %w", err)
|
||||
}
|
||||
return ln, addr.String(), nil
|
||||
}
|
||||
|
||||
func (s *Server) closeListener(ln net.Listener) {
|
||||
if ln == nil {
|
||||
return
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("listener close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.sshServer == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.sshServer.Close(); err != nil {
|
||||
log.Debugf("close SSH server: %v", err)
|
||||
}
|
||||
|
||||
s.sshServer = nil
|
||||
|
||||
maps.Clear(s.sessions)
|
||||
maps.Clear(s.sessionJWTUsers)
|
||||
maps.Clear(s.pendingAuthJWT)
|
||||
maps.Clear(s.sshConnections)
|
||||
|
||||
for _, cancelFunc := range s.sessionCancels {
|
||||
cancelFunc()
|
||||
}
|
||||
maps.Clear(s.sessionCancels)
|
||||
|
||||
for _, listener := range s.remoteForwardListeners {
|
||||
if err := listener.Close(); err != nil {
|
||||
log.Debugf("close remote forward listener: %v", err)
|
||||
}
|
||||
}
|
||||
maps.Clear(s.remoteForwardListeners)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetStatus returns the current status of the SSH server and active sessions
|
||||
func (s *Server) GetStatus() (enabled bool, sessions []SessionInfo) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
enabled = s.sshServer != nil
|
||||
|
||||
for sessionKey, session := range s.sessions {
|
||||
cmd := "<interactive shell>"
|
||||
if len(session.Command()) > 0 {
|
||||
cmd = safeLogCommand(session.Command())
|
||||
}
|
||||
|
||||
jwtUsername := s.sessionJWTUsers[sessionKey]
|
||||
|
||||
sessions = append(sessions, SessionInfo{
|
||||
Username: session.User(),
|
||||
RemoteAddress: session.RemoteAddr().String(),
|
||||
Command: cmd,
|
||||
JWTUsername: jwtUsername,
|
||||
})
|
||||
}
|
||||
|
||||
return enabled, sessions
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.netstackNet = net
|
||||
}
|
||||
|
||||
// SetNetworkValidation configures network-based connection filtering
|
||||
func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
s.mu.RLock()
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
s.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
config := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if config == nil {
|
||||
return fmt.Errorf("JWT config not set")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||
|
||||
validator := jwt.NewValidator(
|
||||
config.Issuer,
|
||||
[]string{config.Audience},
|
||||
config.KeysLocation,
|
||||
true,
|
||||
)
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(
|
||||
jwt.WithAudience(config.Audience),
|
||||
)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.jwtValidator = validator
|
||||
s.jwtExtractor = extractor
|
||||
|
||||
log.Infof("JWT validator initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||
s.mu.RLock()
|
||||
jwtValidator := s.jwtValidator
|
||||
jwtConfig := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtValidator == nil {
|
||||
return nil, fmt.Errorf("JWT validator not initialized")
|
||||
}
|
||||
|
||||
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
|
||||
if err != nil {
|
||||
if jwtConfig != nil {
|
||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||
}
|
||||
}
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
}
|
||||
|
||||
if err := s.checkTokenAge(token, jwtConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
if jwtConfig == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
maxTokenAge := jwtConfig.MaxTokenAge
|
||||
if maxTokenAge <= 0 {
|
||||
maxTokenAge = DefaultJWTMaxTokenAge
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(maxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*auth.UserAuth, error) {
|
||||
s.mu.RLock()
|
||||
jwtExtractor := s.jwtExtractor
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtExtractor == nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
|
||||
}
|
||||
|
||||
userAuth, err := jwtExtractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
|
||||
}
|
||||
|
||||
if !s.hasSSHAccess(&userAuth) {
|
||||
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
|
||||
}
|
||||
|
||||
return &userAuth, nil
|
||||
}
|
||||
|
||||
func (s *Server) hasSSHAccess(userAuth *auth.UserAuth) bool {
|
||||
return userAuth.UserId != ""
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("parse claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
key := newAuthKey(ctx.User(), ctx.RemoteAddr())
|
||||
s.mu.Lock()
|
||||
s.pendingAuthJWT[key] = userAuth.UserId
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if state, exists := s.sshConnections[sshConn]; exists {
|
||||
state.hasActivePortForward = true
|
||||
} else {
|
||||
s.sshConnections[sshConn] = &sshConnectionState{
|
||||
hasActivePortForward: true,
|
||||
username: username,
|
||||
remoteAddr: remoteAddr,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
||||
// We can't extract the SSH connection from net.Conn directly
|
||||
// Connection cleanup will happen during session cleanup or via timeout
|
||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
if ctx == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// Try to match by SSH connection
|
||||
sshConn := ctx.Value(ssh.ContextKeyConn)
|
||||
if sshConn == nil {
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
|
||||
// Look through sessions to find one with matching connection
|
||||
for sessionKey, session := range s.sessions {
|
||||
if session.Context().Value(ssh.ContextKeyConn) == sshConn {
|
||||
return sessionKey
|
||||
}
|
||||
}
|
||||
|
||||
// If no session found, this might be during early connection setup
|
||||
// Return a temporary key that we'll fix up later
|
||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
||||
return tempKey
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
s.mu.RLock()
|
||||
netbirdNetwork := s.wgAddress.Network
|
||||
localIP := s.wgAddress.IP
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !netbirdNetwork.IsValid() || !localIP.IsValid() {
|
||||
return conn
|
||||
}
|
||||
|
||||
remoteAddr := conn.RemoteAddr()
|
||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
if !ok {
|
||||
log.Warnf("SSH connection rejected: invalid remote IP %s", tcpAddr.IP)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Block connections from our own IP (prevent local apps from connecting to ourselves)
|
||||
if remoteIP == localIP {
|
||||
log.Warnf("SSH connection rejected from own IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
if !netbirdNetwork.Contains(remoteIP) {
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("SSH connection from NetBird peer %s allowed", tcpAddr)
|
||||
return conn
|
||||
}
|
||||
|
||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
if err := enableUserSwitching(); err != nil {
|
||||
log.Warnf("failed to enable user switching: %v", err)
|
||||
}
|
||||
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
|
||||
if s.jwtEnabled {
|
||||
serverVersion += " " + detection.JWTRequiredMarker
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
Addr: addr.String(),
|
||||
Handler: s.sessionHandler,
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": s.sftpSubsystemHandler,
|
||||
},
|
||||
HostSigners: []ssh.Signer{},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-tcpip": s.directTCPIPHandler,
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": s.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": s.cancelTcpipForwardHandler,
|
||||
},
|
||||
ConnCallback: s.connectionValidator,
|
||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
if s.jwtEnabled {
|
||||
server.PasswordHandler = s.passwordHandler
|
||||
}
|
||||
|
||||
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
|
||||
if err := server.SetOption(hostKeyPEM); err != nil {
|
||||
return nil, fmt.Errorf("set host key: %w", err)
|
||||
}
|
||||
|
||||
s.configurePortForwarding(server)
|
||||
return server, nil
|
||||
}
|
||||
|
||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.remoteForwardListeners[key] = ln
|
||||
}
|
||||
|
||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
ln, exists := s.remoteForwardListeners[key]
|
||||
if !exists {
|
||||
return false
|
||||
}
|
||||
|
||||
delete(s.remoteForwardListeners, key)
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("remote forward listener close error: %v", err)
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||
var payload struct {
|
||||
Host string
|
||||
Port uint32
|
||||
OriginatorAddr string
|
||||
OriginatorPort uint32
|
||||
}
|
||||
|
||||
if err := cryptossh.Unmarshal(newChan.ExtraData(), &payload); err != nil {
|
||||
if err := newChan.Reject(cryptossh.ConnectionFailed, "parse payload"); err != nil {
|
||||
log.Debugf("channel reject error: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
allowLocal := s.allowLocalPortForwarding
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowLocal {
|
||||
log.Warnf("local port forwarding denied for %s:%d: disabled by configuration", payload.Host, payload.Port)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "local port forwarding disabled")
|
||||
return
|
||||
}
|
||||
|
||||
// Check privilege requirements for the destination port
|
||||
if err := s.checkPortForwardingPrivileges(ctx, "local", payload.Port); err != nil {
|
||||
log.Warnf("local port forwarding denied for %s:%d: %v", payload.Host, payload.Port, err)
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "insufficient privileges")
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("local port forwarding: %s:%d", payload.Host, payload.Port)
|
||||
|
||||
ssh.DirectTCPIPHandler(srv, conn, newChan, ctx)
|
||||
}
|
||||
394
client/ssh/server/server_config_test.go
Normal file
394
client/ssh/server/server_config_test.go
Normal file
@@ -0,0 +1,394 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
)
|
||||
|
||||
func TestServer_RootLoginRestriction(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
username string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "root login allowed",
|
||||
allowRoot: true,
|
||||
username: "root",
|
||||
expectError: false,
|
||||
description: "Root login should succeed when allowed",
|
||||
},
|
||||
{
|
||||
name: "root login denied",
|
||||
allowRoot: false,
|
||||
username: "root",
|
||||
expectError: true,
|
||||
description: "Root login should fail when disabled",
|
||||
},
|
||||
{
|
||||
name: "regular user login always allowed",
|
||||
allowRoot: false,
|
||||
username: "testuser",
|
||||
expectError: false,
|
||||
description: "Regular user login should work regardless of root setting",
|
||||
},
|
||||
}
|
||||
|
||||
// Add Windows Administrator tests if on Windows
|
||||
if runtime.GOOS == "windows" {
|
||||
tests = append(tests, []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
username string
|
||||
expectError bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "Administrator login allowed",
|
||||
allowRoot: true,
|
||||
username: "Administrator",
|
||||
expectError: false,
|
||||
description: "Administrator login should succeed when allowed",
|
||||
},
|
||||
{
|
||||
name: "Administrator login denied",
|
||||
allowRoot: false,
|
||||
username: "Administrator",
|
||||
expectError: true,
|
||||
description: "Administrator login should fail when disabled",
|
||||
},
|
||||
{
|
||||
name: "administrator login denied (lowercase)",
|
||||
allowRoot: false,
|
||||
username: "administrator",
|
||||
expectError: true,
|
||||
description: "administrator login should fail when disabled (case insensitive)",
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock privileged environment to test root access controls
|
||||
// Set up mock users based on platform
|
||||
mockUsers := map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
"testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"),
|
||||
}
|
||||
|
||||
// Add Windows-specific users for Administrator tests
|
||||
if runtime.GOOS == "windows" {
|
||||
mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator")
|
||||
mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator")
|
||||
}
|
||||
|
||||
cleanup := setupTestDependencies(
|
||||
createTestUser("root", "0", "0", "/root"), // Running as root
|
||||
nil,
|
||||
runtime.GOOS,
|
||||
0, // euid 0 (root)
|
||||
mockUsers,
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
// Create server with specific configuration
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(tt.allowRoot)
|
||||
|
||||
// Test the userNameLookup method directly
|
||||
user, err := server.userNameLookup(tt.username)
|
||||
|
||||
if tt.expectError {
|
||||
assert.Error(t, err, tt.description)
|
||||
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
|
||||
// Check for appropriate error message based on platform capabilities
|
||||
errorMsg := err.Error()
|
||||
// Either privileged user restriction OR user switching limitation
|
||||
hasPrivilegedError := strings.Contains(errorMsg, "privileged user")
|
||||
hasSwitchingError := strings.Contains(errorMsg, "cannot switch") || strings.Contains(errorMsg, "user switching not supported")
|
||||
assert.True(t, hasPrivilegedError || hasSwitchingError,
|
||||
"Expected privileged user or user switching error, got: %s", errorMsg)
|
||||
}
|
||||
} else {
|
||||
if tt.username == "root" || strings.ToLower(tt.username) == "administrator" {
|
||||
// For privileged users, we expect either success or a different error
|
||||
// (like user not found), but not the "login disabled" error
|
||||
if err != nil {
|
||||
assert.NotContains(t, err.Error(), "privileged user login is disabled")
|
||||
}
|
||||
} else {
|
||||
// For regular users, lookup should generally succeed or fall back gracefully
|
||||
// Note: may return current user as fallback
|
||||
assert.NotNil(t, user)
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortForwardingRestriction(t *testing.T) {
|
||||
// Test that the port forwarding callbacks properly respect configuration flags
|
||||
// This is a unit test of the callback logic, not a full integration test
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowLocalForwarding bool
|
||||
allowRemoteForwarding bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
name: "all forwarding allowed",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: true,
|
||||
description: "Both local and remote forwarding should be allowed",
|
||||
},
|
||||
{
|
||||
name: "local forwarding disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: true,
|
||||
description: "Local forwarding should be denied when disabled",
|
||||
},
|
||||
{
|
||||
name: "remote forwarding disabled",
|
||||
allowLocalForwarding: true,
|
||||
allowRemoteForwarding: false,
|
||||
description: "Remote forwarding should be denied when disabled",
|
||||
},
|
||||
{
|
||||
name: "all forwarding disabled",
|
||||
allowLocalForwarding: false,
|
||||
allowRemoteForwarding: false,
|
||||
description: "Both forwarding types should be denied when disabled",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create server with specific configuration
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
||||
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
||||
|
||||
// We need to access the internal configuration to simulate the callback tests
|
||||
// Since the callbacks are created inside the Start method, we'll test the logic directly
|
||||
|
||||
// Test the configuration values are set correctly
|
||||
server.mu.RLock()
|
||||
allowLocal := server.allowLocalPortForwarding
|
||||
allowRemote := server.allowRemotePortForwarding
|
||||
server.mu.RUnlock()
|
||||
|
||||
assert.Equal(t, tt.allowLocalForwarding, allowLocal, "Local forwarding configuration should be set correctly")
|
||||
assert.Equal(t, tt.allowRemoteForwarding, allowRemote, "Remote forwarding configuration should be set correctly")
|
||||
|
||||
// Simulate the callback logic
|
||||
localResult := allowLocal // This would be the callback return value
|
||||
remoteResult := allowRemote // This would be the callback return value
|
||||
|
||||
assert.Equal(t, tt.allowLocalForwarding, localResult,
|
||||
"Local port forwarding callback should return correct value")
|
||||
assert.Equal(t, tt.allowRemoteForwarding, remoteResult,
|
||||
"Remote port forwarding callback should return correct value")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_PortConflictHandling(t *testing.T) {
|
||||
// Test that multiple sessions requesting the same local port are handled naturally by the OS
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Get a free port for testing
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
testPort := ln.Addr().(*net.TCPAddr).Port
|
||||
err = ln.Close()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Connect first client
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client1.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Connect second client
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client2.Close()
|
||||
assert.NoError(t, err)
|
||||
}()
|
||||
|
||||
// First client binds to the test port
|
||||
localAddr1 := fmt.Sprintf("127.0.0.1:%d", testPort)
|
||||
remoteAddr := "127.0.0.1:80"
|
||||
|
||||
// Start first client's port forwarding
|
||||
done1 := make(chan error, 1)
|
||||
go func() {
|
||||
// This should succeed and hold the port
|
||||
err := client1.LocalPortForward(ctx1, localAddr1, remoteAddr)
|
||||
done1 <- err
|
||||
}()
|
||||
|
||||
// Give first client time to bind
|
||||
time.Sleep(200 * time.Millisecond)
|
||||
|
||||
// Second client tries to bind to same port
|
||||
localAddr2 := fmt.Sprintf("127.0.0.1:%d", testPort)
|
||||
|
||||
shortCtx, shortCancel := context.WithTimeout(context.Background(), 1*time.Second)
|
||||
defer shortCancel()
|
||||
|
||||
err = client2.LocalPortForward(shortCtx, localAddr2, remoteAddr)
|
||||
// Second client should fail due to "address already in use"
|
||||
assert.Error(t, err, "Second client should fail to bind to same port")
|
||||
if err != nil {
|
||||
// The error should indicate the address is already in use
|
||||
errMsg := strings.ToLower(err.Error())
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, errMsg, "only one usage of each socket address",
|
||||
"Error should indicate port conflict")
|
||||
} else {
|
||||
assert.Contains(t, errMsg, "address already in use",
|
||||
"Error should indicate port conflict")
|
||||
}
|
||||
}
|
||||
|
||||
// Cancel first client's context and wait for it to finish
|
||||
cancel1()
|
||||
select {
|
||||
case err1 := <-done1:
|
||||
// Should get context cancelled or deadline exceeded
|
||||
assert.Error(t, err1, "First client should exit when context cancelled")
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Error("First client did not exit within timeout")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_IsPrivilegedUser(t *testing.T) {
|
||||
|
||||
tests := []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "root",
|
||||
expected: true,
|
||||
description: "root should be considered privileged",
|
||||
},
|
||||
{
|
||||
username: "regular",
|
||||
expected: false,
|
||||
description: "regular user should not be privileged",
|
||||
},
|
||||
{
|
||||
username: "",
|
||||
expected: false,
|
||||
description: "empty username should not be privileged",
|
||||
},
|
||||
}
|
||||
|
||||
// Add Windows-specific tests
|
||||
if runtime.GOOS == "windows" {
|
||||
tests = append(tests, []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "Administrator",
|
||||
expected: true,
|
||||
description: "Administrator should be considered privileged on Windows",
|
||||
},
|
||||
{
|
||||
username: "administrator",
|
||||
expected: true,
|
||||
description: "administrator should be considered privileged on Windows (case insensitive)",
|
||||
},
|
||||
}...)
|
||||
} else {
|
||||
// On non-Windows systems, Administrator should not be privileged
|
||||
tests = append(tests, []struct {
|
||||
username string
|
||||
expected bool
|
||||
description string
|
||||
}{
|
||||
{
|
||||
username: "Administrator",
|
||||
expected: false,
|
||||
description: "Administrator should not be privileged on non-Windows systems",
|
||||
},
|
||||
}...)
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.description, func(t *testing.T) {
|
||||
result := isPrivilegedUsername(tt.username)
|
||||
assert.Equal(t, tt.expected, result, tt.description)
|
||||
})
|
||||
}
|
||||
}
|
||||
441
client/ssh/server/server_test.go
Normal file
441
client/ssh/server/server_test.go
Normal file
@@ -0,0 +1,441 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: key,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestSSHServerIntegration(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with random port
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server in background
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Get a free port
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key for verification
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Create SSH client config
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// Connect to SSH server
|
||||
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
t.Logf("close client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test creating a session
|
||||
session, err := client.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := session.Close(); err != nil {
|
||||
t.Logf("close session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Note: Since we don't have a real shell environment in tests,
|
||||
// we can't test actual command execution, but we can verify
|
||||
// the connection and authentication work
|
||||
t.Log("SSH connection and authentication successful")
|
||||
}
|
||||
|
||||
func TestSSHServerMultipleConnections(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// Test multiple concurrent connections
|
||||
const numConnections = 5
|
||||
results := make(chan error, numConnections)
|
||||
|
||||
for i := 0; i < numConnections; i++ {
|
||||
go func(id int) {
|
||||
client, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("connection %d failed: %w", id, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = client.Close() // Ignore error in test goroutine
|
||||
}()
|
||||
|
||||
session, err := client.NewSession()
|
||||
if err != nil {
|
||||
results <- fmt.Errorf("session %d failed: %w", id, err)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
_ = session.Close() // Ignore error in test goroutine
|
||||
}()
|
||||
|
||||
results <- nil
|
||||
}(i)
|
||||
}
|
||||
|
||||
// Wait for all connections to complete
|
||||
for i := 0; i < numConnections; i++ {
|
||||
select {
|
||||
case err := <-results:
|
||||
assert.NoError(t, err)
|
||||
case <-time.After(10 * time.Second):
|
||||
t.Fatalf("Connection %d timed out", i)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
// Generate host key for server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Generate a client private key for SSH protocol (server doesn't check it)
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Try to connect with client key
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(clientSigner),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// This should succeed in no-auth mode (server doesn't verify keys)
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
||||
if conn != nil {
|
||||
assert.NoError(t, conn.Close())
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServerStartStopCycle(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
serverAddr := "127.0.0.1:0"
|
||||
|
||||
// Test multiple start/stop cycles
|
||||
for i := 0; i < 3; i++ {
|
||||
t.Logf("Start/stop cycle %d", i+1)
|
||||
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-started:
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Cycle %d: Server failed to start: %v", i+1, err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatalf("Cycle %d: Server start timeout", i+1)
|
||||
}
|
||||
|
||||
err = server.Stop()
|
||||
require.NoError(t, err, "Cycle %d: Stop should succeed", i+1)
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServer_WindowsShellHandling(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping Windows shell test in short mode")
|
||||
}
|
||||
|
||||
server := &Server{}
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
// Test Windows cmd.exe shell behavior
|
||||
args := server.getShellCommandArgs("cmd.exe", "echo test")
|
||||
assert.Equal(t, "cmd.exe", args[0])
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
|
||||
// Test PowerShell behavior
|
||||
args = server.getShellCommandArgs("powershell.exe", "echo test")
|
||||
assert.Equal(t, "powershell.exe", args[0])
|
||||
assert.Equal(t, "-Command", args[1])
|
||||
assert.Equal(t, "echo test", args[2])
|
||||
} else {
|
||||
// Test Unix shell behavior
|
||||
args := server.getShellCommandArgs("/bin/sh", "echo test")
|
||||
assert.Equal(t, "/bin/sh", args[0])
|
||||
assert.Equal(t, "-l", args[1])
|
||||
assert.Equal(t, "-c", args[2])
|
||||
assert.Equal(t, "echo test", args[3])
|
||||
}
|
||||
}
|
||||
|
||||
func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig1 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server1 := New(serverConfig1)
|
||||
|
||||
serverConfig2 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server2 := New(serverConfig2)
|
||||
|
||||
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
||||
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
||||
|
||||
server2.SetAllowLocalPortForwarding(true)
|
||||
server2.SetAllowRemotePortForwarding(true)
|
||||
|
||||
assert.True(t, server2.allowLocalPortForwarding, "Local port forwarding should be enabled when explicitly set")
|
||||
assert.True(t, server2.allowRemotePortForwarding, "Remote port forwarding should be enabled when explicitly set")
|
||||
}
|
||||
168
client/ssh/server/session_handlers.go
Normal file
168
client/ssh/server/session_handlers.go
Normal file
@@ -0,0 +1,168 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// sessionHandler handles SSH sessions
|
||||
func (s *Server) sessionHandler(session ssh.Session) {
|
||||
sessionKey := s.registerSession(session)
|
||||
|
||||
key := newAuthKey(session.User(), session.RemoteAddr())
|
||||
s.mu.Lock()
|
||||
jwtUsername := s.pendingAuthJWT[key]
|
||||
if jwtUsername != "" {
|
||||
s.sessionJWTUsers[sessionKey] = jwtUsername
|
||||
delete(s.pendingAuthJWT, key)
|
||||
}
|
||||
s.mu.Unlock()
|
||||
|
||||
logger := log.WithField("session", sessionKey)
|
||||
if jwtUsername != "" {
|
||||
logger = logger.WithField("jwt_user", jwtUsername)
|
||||
logger.Infof("SSH session started (JWT user: %s)", jwtUsername)
|
||||
} else {
|
||||
logger.Infof("SSH session started")
|
||||
}
|
||||
sessionStart := time.Now()
|
||||
|
||||
defer s.unregisterSession(sessionKey, session)
|
||||
defer func() {
|
||||
duration := time.Since(sessionStart).Round(time.Millisecond)
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
logger.Warnf("close session after %v: %v", duration, err)
|
||||
}
|
||||
logger.Infof("SSH session closed after %v", duration)
|
||||
}()
|
||||
|
||||
privilegeResult, err := s.userPrivilegeCheck(session.User())
|
||||
if err != nil {
|
||||
s.handlePrivError(logger, session, err)
|
||||
return
|
||||
}
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
hasCommand := len(session.Command()) > 0
|
||||
|
||||
switch {
|
||||
case isPty && hasCommand:
|
||||
// ssh -t <host> <cmd> - Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, winCh)
|
||||
case isPty:
|
||||
// ssh <host> - Pty interactive session (login)
|
||||
s.handlePty(logger, session, privilegeResult, ptyReq, winCh)
|
||||
case hasCommand:
|
||||
// ssh <host> <cmd> - non-Pty command execution
|
||||
s.handleCommand(logger, session, privilegeResult, nil)
|
||||
default:
|
||||
s.rejectInvalidSession(logger, session)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *Server) rejectInvalidSession(logger *log.Entry, session ssh.Session) {
|
||||
if _, err := io.WriteString(session, "no command specified and Pty not requested\n"); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
logger.Infof("rejected non-Pty session without command from %s", session.RemoteAddr())
|
||||
}
|
||||
|
||||
func (s *Server) registerSession(session ssh.Session) SessionKey {
|
||||
sessionID := session.Context().Value(ssh.ContextKeySessionID)
|
||||
if sessionID == nil {
|
||||
sessionID = fmt.Sprintf("%p", session)
|
||||
}
|
||||
|
||||
// Create a short 4-byte identifier from the full session ID
|
||||
hasher := sha256.New()
|
||||
hasher.Write([]byte(fmt.Sprintf("%v", sessionID)))
|
||||
hash := hasher.Sum(nil)
|
||||
shortID := hex.EncodeToString(hash[:4])
|
||||
|
||||
remoteAddr := session.RemoteAddr().String()
|
||||
username := session.User()
|
||||
sessionKey := SessionKey(fmt.Sprintf("%s@%s-%s", username, remoteAddr, shortID))
|
||||
|
||||
s.mu.Lock()
|
||||
s.sessions[sessionKey] = session
|
||||
s.mu.Unlock()
|
||||
|
||||
return sessionKey
|
||||
}
|
||||
|
||||
func (s *Server) unregisterSession(sessionKey SessionKey, session ssh.Session) {
|
||||
s.mu.Lock()
|
||||
delete(s.sessions, sessionKey)
|
||||
delete(s.sessionJWTUsers, sessionKey)
|
||||
|
||||
// Cancel all port forwarding connections for this session
|
||||
var connectionsToCancel []ConnectionKey
|
||||
for key := range s.sessionCancels {
|
||||
if strings.HasPrefix(string(key), string(sessionKey)+"-") {
|
||||
connectionsToCancel = append(connectionsToCancel, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, key := range connectionsToCancel {
|
||||
if cancelFunc, exists := s.sessionCancels[key]; exists {
|
||||
log.WithField("session", sessionKey).Debugf("cancelling port forwarding context: %s", key)
|
||||
cancelFunc()
|
||||
delete(s.sessionCancels, key)
|
||||
}
|
||||
}
|
||||
|
||||
if sshConnValue := session.Context().Value(ssh.ContextKeyConn); sshConnValue != nil {
|
||||
if sshConn, ok := sshConnValue.(*cryptossh.ServerConn); ok {
|
||||
delete(s.sshConnections, sshConn)
|
||||
}
|
||||
}
|
||||
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
func (s *Server) handlePrivError(logger *log.Entry, session ssh.Session, err error) {
|
||||
logger.Warnf("user privilege check failed: %v", err)
|
||||
|
||||
errorMsg := s.buildUserLookupErrorMessage(err)
|
||||
|
||||
if _, writeErr := fmt.Fprint(session, errorMsg); writeErr != nil {
|
||||
logger.Debugf(errWriteSession, writeErr)
|
||||
}
|
||||
if exitErr := session.Exit(1); exitErr != nil {
|
||||
logSessionExitError(logger, exitErr)
|
||||
}
|
||||
}
|
||||
|
||||
// buildUserLookupErrorMessage creates appropriate user-facing error messages based on error type
|
||||
func (s *Server) buildUserLookupErrorMessage(err error) string {
|
||||
var privilegedErr *PrivilegedUserError
|
||||
|
||||
switch {
|
||||
case errors.As(err, &privilegedErr):
|
||||
if privilegedErr.Username == "root" {
|
||||
return "root login is disabled on this SSH server\n"
|
||||
}
|
||||
return "privileged user access is disabled on this SSH server\n"
|
||||
|
||||
case errors.Is(err, ErrPrivilegeRequired):
|
||||
return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
|
||||
|
||||
case errors.Is(err, ErrPrivilegedUserSwitch):
|
||||
return "Cannot switch to privileged user - current user lacks required privileges\n"
|
||||
|
||||
default:
|
||||
return "User authentication failed\n"
|
||||
}
|
||||
}
|
||||
22
client/ssh/server/session_handlers_js.go
Normal file
22
client/ssh/server/session_handlers_js.go
Normal file
@@ -0,0 +1,22 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// handlePty is not supported on JS/WASM
|
||||
func (s *Server) handlePty(logger *log.Entry, session ssh.Session, _ PrivilegeCheckResult, _ ssh.Pty, _ <-chan ssh.Window) bool {
|
||||
errorMsg := "PTY sessions are not supported on WASM/JS platform\n"
|
||||
if _, err := fmt.Fprint(session.Stderr(), errorMsg); err != nil {
|
||||
logger.Debugf(errWriteSession, err)
|
||||
}
|
||||
if err := session.Exit(1); err != nil {
|
||||
logSessionExitError(logger, err)
|
||||
}
|
||||
return false
|
||||
}
|
||||
81
client/ssh/server/sftp.go
Normal file
81
client/ssh/server/sftp.go
Normal file
@@ -0,0 +1,81 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/pkg/sftp"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// SetAllowSFTP enables or disables SFTP support
|
||||
func (s *Server) SetAllowSFTP(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowSFTP = allow
|
||||
}
|
||||
|
||||
// sftpSubsystemHandler handles SFTP subsystem requests
|
||||
func (s *Server) sftpSubsystemHandler(sess ssh.Session) {
|
||||
s.mu.RLock()
|
||||
allowSFTP := s.allowSFTP
|
||||
s.mu.RUnlock()
|
||||
|
||||
if !allowSFTP {
|
||||
log.Debugf("SFTP subsystem request denied: SFTP disabled")
|
||||
if err := sess.Exit(1); err != nil {
|
||||
log.Debugf("SFTP session exit failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: sess.User(),
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSFTP,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
log.Warnf("SFTP access denied for user %s from %s: %v", sess.User(), sess.RemoteAddr(), result.Error)
|
||||
if err := sess.Exit(1); err != nil {
|
||||
log.Debugf("exit SFTP session: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("SFTP subsystem request from user %s (effective user %s)", sess.User(), result.User.Username)
|
||||
|
||||
if !result.RequiresUserSwitching {
|
||||
if err := s.executeSftpDirect(sess); err != nil {
|
||||
log.Errorf("SFTP direct execution: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
if err := s.executeSftpWithPrivilegeDrop(sess, result.User); err != nil {
|
||||
log.Errorf("SFTP privilege drop execution: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// executeSftpDirect executes SFTP directly without privilege dropping
|
||||
func (s *Server) executeSftpDirect(sess ssh.Session) error {
|
||||
log.Debugf("starting SFTP session for user %s (no privilege dropping)", sess.User())
|
||||
|
||||
sftpServer, err := sftp.NewServer(sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("SFTP server creation: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := sftpServer.Close(); err != nil {
|
||||
log.Debugf("failed to close sftp server: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := sftpServer.Serve(); err != nil && err != io.EOF {
|
||||
return fmt.Errorf("serve: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
12
client/ssh/server/sftp_js.go
Normal file
12
client/ssh/server/sftp_js.go
Normal file
@@ -0,0 +1,12 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"os/user"
|
||||
)
|
||||
|
||||
// parseUserCredentials is not supported on JS/WASM
|
||||
func (s *Server) parseUserCredentials(_ *user.User) (uint32, uint32, []uint32, error) {
|
||||
return 0, 0, nil, errNotSupported
|
||||
}
|
||||
228
client/ssh/server/sftp_test.go
Normal file
228
client/ssh/server/sftp_test.go
Normal file
@@ -0,0 +1,228 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/user"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/pkg/sftp"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
||||
// Skip SFTP test when running as root due to protocol issues in some environments
|
||||
if os.Geteuid() == 0 {
|
||||
t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
|
||||
}
|
||||
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP enabled
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(true)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// (currentUser already obtained at function start)
|
||||
|
||||
// Create SSH client connection
|
||||
clientConfig := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
||||
require.NoError(t, err, "SSH connection should succeed")
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Create SFTP client
|
||||
sftpClient, err := sftp.NewClient(conn)
|
||||
require.NoError(t, err, "SFTP client creation should succeed")
|
||||
defer func() {
|
||||
if err := sftpClient.Close(); err != nil {
|
||||
t.Logf("SFTP client close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Test basic SFTP operations
|
||||
workingDir, err := sftpClient.Getwd()
|
||||
assert.NoError(t, err, "Should be able to get working directory")
|
||||
assert.NotEmpty(t, workingDir, "Working directory should not be empty")
|
||||
|
||||
// Test directory listing
|
||||
files, err := sftpClient.ReadDir(".")
|
||||
assert.NoError(t, err, "Should be able to list current directory")
|
||||
assert.NotNil(t, files, "File list should not be nil")
|
||||
}
|
||||
|
||||
func TestSSHServer_SFTPDisabled(t *testing.T) {
|
||||
// Get current user for SSH connection
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Generate host key for server
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP disabled
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(false)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", serverAddr)
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
serverAddr = actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse client private key
|
||||
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
hostPubKey := hostPrivParsed.PublicKey()
|
||||
|
||||
// (currentUser already obtained at function start)
|
||||
|
||||
// Create SSH client connection
|
||||
clientConfig := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(signer),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 5 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
||||
require.NoError(t, err, "SSH connection should succeed")
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Try to create SFTP client - should fail when SFTP is disabled
|
||||
_, err = sftp.NewClient(conn)
|
||||
assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
|
||||
}
|
||||
71
client/ssh/server/sftp_unix.go
Normal file
71
client/ssh/server/sftp_unix.go
Normal file
@@ -0,0 +1,71 @@
|
||||
//go:build !windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// executeSftpWithPrivilegeDrop executes SFTP using Unix privilege dropping
|
||||
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
||||
uid, gid, groups, err := s.parseUserCredentials(targetUser)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
|
||||
sftpCmd, err := s.createSftpExecutorCommand(sess, uid, gid, groups, targetUser.HomeDir)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create executor: %w", err)
|
||||
}
|
||||
|
||||
sftpCmd.Stdin = sess
|
||||
sftpCmd.Stdout = sess
|
||||
sftpCmd.Stderr = sess.Stderr()
|
||||
|
||||
log.Tracef("starting SFTP with privilege dropping to user %s (UID=%d, GID=%d)", targetUser.Username, uid, gid)
|
||||
|
||||
if err := sftpCmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting SFTP executor: %w", err)
|
||||
}
|
||||
|
||||
if err := sftpCmd.Wait(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
log.Tracef("SFTP process exited with code %d", exitError.ExitCode())
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("exec: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createSftpExecutorCommand creates a command that spawns netbird ssh sftp for privilege dropping
|
||||
func (s *Server) createSftpExecutorCommand(sess ssh.Session, uid, gid uint32, groups []uint32, workingDir string) (*exec.Cmd, error) {
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "sftp",
|
||||
"--uid", strconv.FormatUint(uint64(uid), 10),
|
||||
"--gid", strconv.FormatUint(uint64(gid), 10),
|
||||
"--working-dir", workingDir,
|
||||
}
|
||||
|
||||
for _, group := range groups {
|
||||
args = append(args, "--groups", strconv.FormatUint(uint64(group), 10))
|
||||
}
|
||||
|
||||
log.Tracef("creating SFTP executor command: %s %v", netbirdPath, args)
|
||||
return exec.CommandContext(sess.Context(), netbirdPath, args...), nil
|
||||
}
|
||||
91
client/ssh/server/sftp_windows.go
Normal file
91
client/ssh/server/sftp_windows.go
Normal file
@@ -0,0 +1,91 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// createSftpCommand creates a Windows SFTP command with user switching.
|
||||
// The caller must close the returned token handle after starting the process.
|
||||
func (s *Server) createSftpCommand(targetUser *user.User, sess ssh.Session) (*exec.Cmd, windows.Token, error) {
|
||||
username, domain := s.parseUsername(targetUser.Username)
|
||||
|
||||
netbirdPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("get netbird executable path: %w", err)
|
||||
}
|
||||
|
||||
args := []string{
|
||||
"ssh", "sftp",
|
||||
"--working-dir", targetUser.HomeDir,
|
||||
"--windows-username", username,
|
||||
"--windows-domain", domain,
|
||||
}
|
||||
|
||||
pd := NewPrivilegeDropper()
|
||||
token, err := pd.createToken(username, domain)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create token: %w", err)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(token); err != nil {
|
||||
log.Warnf("failed to close impersonation token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
cmd, primaryToken, err := pd.createProcessWithToken(sess.Context(), windows.Token(token), netbirdPath, append([]string{netbirdPath}, args...), targetUser.HomeDir)
|
||||
if err != nil {
|
||||
return nil, 0, fmt.Errorf("create SFTP command: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Created Windows SFTP command with user switching for %s", targetUser.Username)
|
||||
return cmd, primaryToken, nil
|
||||
}
|
||||
|
||||
// executeSftpCommand executes a Windows SFTP command with proper I/O handling
|
||||
func (s *Server) executeSftpCommand(sess ssh.Session, sftpCmd *exec.Cmd, token windows.Token) error {
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
sftpCmd.Stdin = sess
|
||||
sftpCmd.Stdout = sess
|
||||
sftpCmd.Stderr = sess.Stderr()
|
||||
|
||||
if err := sftpCmd.Start(); err != nil {
|
||||
return fmt.Errorf("starting sftp executor: %w", err)
|
||||
}
|
||||
|
||||
if err := sftpCmd.Wait(); err != nil {
|
||||
var exitError *exec.ExitError
|
||||
if errors.As(err, &exitError) {
|
||||
log.Tracef("sftp process exited with code %d", exitError.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("exec sftp: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// executeSftpWithPrivilegeDrop executes SFTP using Windows privilege dropping
|
||||
func (s *Server) executeSftpWithPrivilegeDrop(sess ssh.Session, targetUser *user.User) error {
|
||||
sftpCmd, token, err := s.createSftpCommand(targetUser, sess)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create sftp: %w", err)
|
||||
}
|
||||
return s.executeSftpCommand(sess, sftpCmd, token)
|
||||
}
|
||||
180
client/ssh/server/shell.go
Normal file
180
client/ssh/server/shell.go
Normal file
@@ -0,0 +1,180 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
defaultUnixShell = "/bin/sh"
|
||||
|
||||
pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name
|
||||
powershellExe = "powershell.exe"
|
||||
)
|
||||
|
||||
// getUserShell returns the appropriate shell for the given user ID
|
||||
// Handles all platform-specific logic and fallbacks consistently
|
||||
func getUserShell(userID string) string {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
return getWindowsUserShell()
|
||||
default:
|
||||
return getUnixUserShell(userID)
|
||||
}
|
||||
}
|
||||
|
||||
// getWindowsUserShell returns the best shell for Windows users.
|
||||
// We intentionally do not support cmd.exe or COMSPEC fallbacks to avoid command injection
|
||||
// vulnerabilities that arise from cmd.exe's complex command line parsing and special characters.
|
||||
// PowerShell provides safer argument handling and is available on all modern Windows systems.
|
||||
// Order: pwsh.exe -> powershell.exe
|
||||
func getWindowsUserShell() string {
|
||||
if path, err := exec.LookPath(pwshExe); err == nil {
|
||||
return path
|
||||
}
|
||||
if path, err := exec.LookPath(powershellExe); err == nil {
|
||||
return path
|
||||
}
|
||||
|
||||
return `C:\Windows\System32\WindowsPowerShell\v1.0\powershell.exe`
|
||||
}
|
||||
|
||||
// getUnixUserShell returns the shell for Unix-like systems
|
||||
func getUnixUserShell(userID string) string {
|
||||
shell := getShellFromPasswd(userID)
|
||||
if shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
if shell := os.Getenv("SHELL"); shell != "" {
|
||||
return shell
|
||||
}
|
||||
|
||||
return defaultUnixShell
|
||||
}
|
||||
|
||||
// getShellFromPasswd reads the shell from /etc/passwd for the given user ID
|
||||
func getShellFromPasswd(userID string) string {
|
||||
file, err := os.Open("/etc/passwd")
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Warnf("close /etc/passwd file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := scanner.Text()
|
||||
fields := strings.Split(line, ":")
|
||||
if len(fields) < 7 {
|
||||
continue
|
||||
}
|
||||
|
||||
// field 2 is UID
|
||||
if fields[2] == userID {
|
||||
shell := strings.TrimSpace(fields[6])
|
||||
return shell
|
||||
}
|
||||
}
|
||||
|
||||
if err := scanner.Err(); err != nil {
|
||||
log.Warnf("error reading /etc/passwd: %v", err)
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// prepareUserEnv prepares environment variables for user execution
|
||||
func prepareUserEnv(user *user.User, shell string) []string {
|
||||
pathValue := "/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"
|
||||
if runtime.GOOS == "windows" {
|
||||
pathValue = `C:\Windows\System32;C:\Windows;C:\Windows\System32\Wbem;C:\Windows\System32\WindowsPowerShell\v1.0`
|
||||
}
|
||||
|
||||
return []string{
|
||||
fmt.Sprint("SHELL=" + shell),
|
||||
fmt.Sprint("USER=" + user.Username),
|
||||
fmt.Sprint("LOGNAME=" + user.Username),
|
||||
fmt.Sprint("HOME=" + user.HomeDir),
|
||||
"PATH=" + pathValue,
|
||||
}
|
||||
}
|
||||
|
||||
// acceptEnv checks if environment variable from SSH client should be accepted
|
||||
// This is a whitelist of variables that SSH clients can send to the server
|
||||
func acceptEnv(envVar string) bool {
|
||||
varName := envVar
|
||||
if idx := strings.Index(envVar, "="); idx != -1 {
|
||||
varName = envVar[:idx]
|
||||
}
|
||||
|
||||
exactMatches := []string{
|
||||
"LANG",
|
||||
"LANGUAGE",
|
||||
"TERM",
|
||||
"COLORTERM",
|
||||
"EDITOR",
|
||||
"VISUAL",
|
||||
"PAGER",
|
||||
"LESS",
|
||||
"LESSCHARSET",
|
||||
"TZ",
|
||||
}
|
||||
|
||||
prefixMatches := []string{
|
||||
"LC_",
|
||||
}
|
||||
|
||||
for _, exact := range exactMatches {
|
||||
if varName == exact {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
for _, prefix := range prefixMatches {
|
||||
if strings.HasPrefix(varName, prefix) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// prepareSSHEnv prepares SSH protocol-specific environment variables
|
||||
// These variables provide information about the SSH connection itself
|
||||
func prepareSSHEnv(session ssh.Session) []string {
|
||||
remoteAddr := session.RemoteAddr()
|
||||
localAddr := session.LocalAddr()
|
||||
|
||||
remoteHost, remotePort, err := net.SplitHostPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
remoteHost = remoteAddr.String()
|
||||
remotePort = "0"
|
||||
}
|
||||
|
||||
localHost, localPort, err := net.SplitHostPort(localAddr.String())
|
||||
if err != nil {
|
||||
localHost = localAddr.String()
|
||||
localPort = strconv.Itoa(InternalSSHPort)
|
||||
}
|
||||
|
||||
return []string{
|
||||
// SSH_CLIENT format: "client_ip client_port server_port"
|
||||
fmt.Sprintf("SSH_CLIENT=%s %s %s", remoteHost, remotePort, localPort),
|
||||
// SSH_CONNECTION format: "client_ip client_port server_ip server_port"
|
||||
fmt.Sprintf("SSH_CONNECTION=%s %s %s %s", remoteHost, remotePort, localHost, localPort),
|
||||
}
|
||||
}
|
||||
45
client/ssh/server/test.go
Normal file
45
client/ssh/server/test.go
Normal file
@@ -0,0 +1,45 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"testing"
|
||||
"time"
|
||||
)
|
||||
|
||||
func StartTestServer(t *testing.T, server *Server) string {
|
||||
started := make(chan string, 1)
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
actualAddr := ln.Addr().String()
|
||||
if err := ln.Close(); err != nil {
|
||||
errChan <- fmt.Errorf("close temp listener: %w", err)
|
||||
return
|
||||
}
|
||||
|
||||
addrPort := netip.MustParseAddrPort(actualAddr)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
case actualAddr := <-started:
|
||||
return actualAddr
|
||||
case err := <-errChan:
|
||||
t.Fatalf("Server failed to start: %v", err)
|
||||
case <-time.After(5 * time.Second):
|
||||
t.Fatal("Server start timeout")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
411
client/ssh/server/user_utils.go
Normal file
411
client/ssh/server/user_utils.go
Normal file
@@ -0,0 +1,411 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrPrivilegeRequired = errors.New("SeAssignPrimaryTokenPrivilege required for user switching - NetBird must run with elevated privileges")
|
||||
ErrPrivilegedUserSwitch = errors.New("cannot switch to privileged user - current user lacks required privileges")
|
||||
)
|
||||
|
||||
// isPlatformUnix returns true for Unix-like platforms (Linux, macOS, etc.)
|
||||
func isPlatformUnix() bool {
|
||||
return getCurrentOS() != "windows"
|
||||
}
|
||||
|
||||
// Dependency injection variables for testing - allows mocking dynamic runtime checks
|
||||
var (
|
||||
getCurrentUser = user.Current
|
||||
lookupUser = user.Lookup
|
||||
getCurrentOS = func() string { return runtime.GOOS }
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
getEuid = os.Geteuid
|
||||
)
|
||||
|
||||
const (
|
||||
// FeatureSSHLogin represents SSH login operations for privilege checking
|
||||
FeatureSSHLogin = "SSH login"
|
||||
// FeatureSFTP represents SFTP operations for privilege checking
|
||||
FeatureSFTP = "SFTP"
|
||||
)
|
||||
|
||||
// PrivilegeCheckRequest represents a privilege check request
|
||||
type PrivilegeCheckRequest struct {
|
||||
// Username being requested (empty = current user)
|
||||
RequestedUsername string
|
||||
FeatureSupportsUserSwitch bool // Does this feature/operation support user switching?
|
||||
FeatureName string
|
||||
}
|
||||
|
||||
// PrivilegeCheckResult represents the result of a privilege check
|
||||
type PrivilegeCheckResult struct {
|
||||
// Allowed indicates whether the privilege check passed
|
||||
Allowed bool
|
||||
// User is the effective user to use for the operation (nil if not allowed)
|
||||
User *user.User
|
||||
// Error contains the reason for denial (nil if allowed)
|
||||
Error error
|
||||
// UsedFallback indicates we fell back to current user instead of requested user.
|
||||
// This happens on Unix when running as an unprivileged user (e.g., in containers)
|
||||
// where there's no point in user switching since we lack privileges anyway.
|
||||
// When true, all privilege checks have already been performed and no additional
|
||||
// privilege dropping or root checks are needed - the current user is the target.
|
||||
UsedFallback bool
|
||||
// RequiresUserSwitching indicates whether user switching will actually occur
|
||||
// (false for fallback cases where no actual switching happens)
|
||||
RequiresUserSwitching bool
|
||||
}
|
||||
|
||||
// CheckPrivileges performs comprehensive privilege checking for all SSH features.
|
||||
// This is the single source of truth for privilege decisions across the SSH server.
|
||||
func (s *Server) CheckPrivileges(req PrivilegeCheckRequest) PrivilegeCheckResult {
|
||||
context, err := s.buildPrivilegeCheckContext(req.FeatureName)
|
||||
if err != nil {
|
||||
return PrivilegeCheckResult{Allowed: false, Error: err}
|
||||
}
|
||||
|
||||
// Handle empty username case - but still check root access controls
|
||||
if req.RequestedUsername == "" {
|
||||
if isPrivilegedUsername(context.currentUser.Username) && !context.allowRoot {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: &PrivilegedUserError{Username: context.currentUser.Username},
|
||||
}
|
||||
}
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: context.currentUser,
|
||||
RequiresUserSwitching: false,
|
||||
}
|
||||
}
|
||||
|
||||
return s.checkUserRequest(context, req)
|
||||
}
|
||||
|
||||
// buildPrivilegeCheckContext gathers all the context needed for privilege checking
|
||||
func (s *Server) buildPrivilegeCheckContext(featureName string) (*privilegeCheckContext, error) {
|
||||
currentUser, err := getCurrentUser()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get current user for %s: %w", featureName, err)
|
||||
}
|
||||
|
||||
s.mu.RLock()
|
||||
allowRoot := s.allowRootLogin
|
||||
s.mu.RUnlock()
|
||||
|
||||
return &privilegeCheckContext{
|
||||
currentUser: currentUser,
|
||||
currentUserPrivileged: getIsProcessPrivileged(),
|
||||
allowRoot: allowRoot,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// checkUserRequest handles normal privilege checking flow for specific usernames
|
||||
func (s *Server) checkUserRequest(ctx *privilegeCheckContext, req PrivilegeCheckRequest) PrivilegeCheckResult {
|
||||
if !ctx.currentUserPrivileged && isPlatformUnix() {
|
||||
log.Debugf("Unix non-privileged shortcut: falling back to current user %s for %s (requested: %s)",
|
||||
ctx.currentUser.Username, req.FeatureName, req.RequestedUsername)
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: ctx.currentUser,
|
||||
UsedFallback: true,
|
||||
RequiresUserSwitching: false,
|
||||
}
|
||||
}
|
||||
|
||||
resolvedUser, err := s.resolveRequestedUser(req.RequestedUsername)
|
||||
if err != nil {
|
||||
// Calculate if user switching would be required even if lookup failed
|
||||
needsUserSwitching := !isSameUser(req.RequestedUsername, ctx.currentUser.Username)
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: err,
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
needsUserSwitching := !isSameResolvedUser(resolvedUser, ctx.currentUser)
|
||||
|
||||
if isPrivilegedUsername(resolvedUser.Username) && !ctx.allowRoot {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: &PrivilegedUserError{Username: resolvedUser.Username},
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
if needsUserSwitching && !req.FeatureSupportsUserSwitch {
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: false,
|
||||
Error: fmt.Errorf("%s: user switching not supported by this feature", req.FeatureName),
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
return PrivilegeCheckResult{
|
||||
Allowed: true,
|
||||
User: resolvedUser,
|
||||
RequiresUserSwitching: needsUserSwitching,
|
||||
}
|
||||
}
|
||||
|
||||
// resolveRequestedUser resolves a username to its canonical user identity
|
||||
func (s *Server) resolveRequestedUser(requestedUsername string) (*user.User, error) {
|
||||
if requestedUsername == "" {
|
||||
return getCurrentUser()
|
||||
}
|
||||
|
||||
if err := validateUsername(requestedUsername); err != nil {
|
||||
return nil, fmt.Errorf("invalid username %q: %w", requestedUsername, err)
|
||||
}
|
||||
|
||||
u, err := lookupUser(requestedUsername)
|
||||
if err != nil {
|
||||
return nil, &UserNotFoundError{Username: requestedUsername, Cause: err}
|
||||
}
|
||||
return u, nil
|
||||
}
|
||||
|
||||
// isSameResolvedUser compares two resolved user identities
|
||||
func isSameResolvedUser(user1, user2 *user.User) bool {
|
||||
if user1 == nil || user2 == nil {
|
||||
return user1 == user2
|
||||
}
|
||||
return user1.Uid == user2.Uid
|
||||
}
|
||||
|
||||
// privilegeCheckContext holds all context needed for privilege checking
|
||||
type privilegeCheckContext struct {
|
||||
currentUser *user.User
|
||||
currentUserPrivileged bool
|
||||
allowRoot bool
|
||||
}
|
||||
|
||||
// isSameUser checks if two usernames refer to the same user
|
||||
// SECURITY: This function must be conservative - it should only return true
|
||||
// when we're certain both usernames refer to the exact same user identity
|
||||
func isSameUser(requestedUsername, currentUsername string) bool {
|
||||
// Empty requested username means current user
|
||||
if requestedUsername == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Exact match (most common case)
|
||||
if getCurrentOS() == "windows" {
|
||||
if strings.EqualFold(requestedUsername, currentUsername) {
|
||||
return true
|
||||
}
|
||||
} else {
|
||||
if requestedUsername == currentUsername {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Windows domain resolution: only allow domain stripping when comparing
|
||||
// a bare username against the current user's domain-qualified name
|
||||
if getCurrentOS() == "windows" {
|
||||
return isWindowsSameUser(requestedUsername, currentUsername)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isWindowsSameUser handles Windows-specific user comparison with domain logic
|
||||
func isWindowsSameUser(requestedUsername, currentUsername string) bool {
|
||||
// Extract domain and username parts
|
||||
extractParts := func(name string) (domain, user string) {
|
||||
// Handle DOMAIN\username format
|
||||
if idx := strings.LastIndex(name, `\`); idx != -1 {
|
||||
return name[:idx], name[idx+1:]
|
||||
}
|
||||
// Handle user@domain.com format
|
||||
if idx := strings.Index(name, "@"); idx != -1 {
|
||||
return name[idx+1:], name[:idx]
|
||||
}
|
||||
// No domain specified - local machine
|
||||
return "", name
|
||||
}
|
||||
|
||||
reqDomain, reqUser := extractParts(requestedUsername)
|
||||
curDomain, curUser := extractParts(currentUsername)
|
||||
|
||||
// Case-insensitive username comparison
|
||||
if !strings.EqualFold(reqUser, curUser) {
|
||||
return false
|
||||
}
|
||||
|
||||
// If requested username has no domain, it refers to local machine user
|
||||
// Allow this to match the current user regardless of current user's domain
|
||||
if reqDomain == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
// If both have domains, they must match exactly (case-insensitive)
|
||||
return strings.EqualFold(reqDomain, curDomain)
|
||||
}
|
||||
|
||||
// SetAllowRootLogin configures root login access
|
||||
func (s *Server) SetAllowRootLogin(allow bool) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.allowRootLogin = allow
|
||||
}
|
||||
|
||||
// userNameLookup performs user lookup with root login permission check
|
||||
func (s *Server) userNameLookup(username string) (*user.User, error) {
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSSHLogin,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return nil, result.Error
|
||||
}
|
||||
|
||||
return result.User, nil
|
||||
}
|
||||
|
||||
// userPrivilegeCheck performs user lookup with full privilege check result
|
||||
func (s *Server) userPrivilegeCheck(username string) (PrivilegeCheckResult, error) {
|
||||
result := s.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: username,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: FeatureSSHLogin,
|
||||
})
|
||||
|
||||
if !result.Allowed {
|
||||
return result, result.Error
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// isPrivilegedUsername checks if the given username represents a privileged user across platforms.
|
||||
// On Unix: root
|
||||
// On Windows: Administrator, SYSTEM (case-insensitive)
|
||||
// Handles domain-qualified usernames like "DOMAIN\Administrator" or "user@domain.com"
|
||||
func isPrivilegedUsername(username string) bool {
|
||||
if getCurrentOS() != "windows" {
|
||||
return username == "root"
|
||||
}
|
||||
|
||||
bareUsername := username
|
||||
// Handle Windows domain format: DOMAIN\username
|
||||
if idx := strings.LastIndex(username, `\`); idx != -1 {
|
||||
bareUsername = username[idx+1:]
|
||||
}
|
||||
// Handle email-style format: username@domain.com
|
||||
if idx := strings.Index(bareUsername, "@"); idx != -1 {
|
||||
bareUsername = bareUsername[:idx]
|
||||
}
|
||||
|
||||
return isWindowsPrivilegedUser(bareUsername)
|
||||
}
|
||||
|
||||
// isWindowsPrivilegedUser checks if a bare username (domain already stripped) represents a Windows privileged account
|
||||
func isWindowsPrivilegedUser(bareUsername string) bool {
|
||||
// common privileged usernames (case insensitive)
|
||||
privilegedNames := []string{
|
||||
"administrator",
|
||||
"admin",
|
||||
"root",
|
||||
"system",
|
||||
"localsystem",
|
||||
"networkservice",
|
||||
"localservice",
|
||||
}
|
||||
|
||||
usernameLower := strings.ToLower(bareUsername)
|
||||
for _, privilegedName := range privilegedNames {
|
||||
if usernameLower == privilegedName {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// computer accounts (ending with $) are not privileged by themselves
|
||||
// They only gain privileges through group membership or specific SIDs
|
||||
|
||||
if targetUser, err := lookupUser(bareUsername); err == nil {
|
||||
return isWindowsPrivilegedSID(targetUser.Uid)
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isWindowsPrivilegedSID checks if a Windows SID represents a privileged account
|
||||
func isWindowsPrivilegedSID(sid string) bool {
|
||||
privilegedSIDs := []string{
|
||||
"S-1-5-18", // Local System (SYSTEM)
|
||||
"S-1-5-19", // Local Service (NT AUTHORITY\LOCAL SERVICE)
|
||||
"S-1-5-20", // Network Service (NT AUTHORITY\NETWORK SERVICE)
|
||||
"S-1-5-32-544", // Administrators group (BUILTIN\Administrators)
|
||||
"S-1-5-500", // Built-in Administrator account (local machine RID 500)
|
||||
}
|
||||
|
||||
for _, privilegedSID := range privilegedSIDs {
|
||||
if sid == privilegedSID {
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
// Check for domain administrator accounts (RID 500 in any domain)
|
||||
// Format: S-1-5-21-domain-domain-domain-500
|
||||
// This is reliable as RID 500 is reserved for the domain Administrator account
|
||||
if strings.HasPrefix(sid, "S-1-5-21-") && strings.HasSuffix(sid, "-500") {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for other well-known privileged RIDs in domain contexts
|
||||
// RID 512 = Domain Admins group, RID 516 = Domain Controllers group
|
||||
if strings.HasPrefix(sid, "S-1-5-21-") {
|
||||
if strings.HasSuffix(sid, "-512") || // Domain Admins group
|
||||
strings.HasSuffix(sid, "-516") || // Domain Controllers group
|
||||
strings.HasSuffix(sid, "-519") { // Enterprise Admins group
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isCurrentProcessPrivileged checks if the current process is running with elevated privileges.
|
||||
// On Unix systems, this means running as root (UID 0).
|
||||
// On Windows, this means running as Administrator or SYSTEM.
|
||||
func isCurrentProcessPrivileged() bool {
|
||||
if getCurrentOS() == "windows" {
|
||||
return isWindowsElevated()
|
||||
}
|
||||
return getEuid() == 0
|
||||
}
|
||||
|
||||
// isWindowsElevated checks if the current process is running with elevated privileges on Windows
|
||||
func isWindowsElevated() bool {
|
||||
currentUser, err := getCurrentUser()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get current user for privilege check, assuming non-privileged: %v", err)
|
||||
return false
|
||||
}
|
||||
|
||||
if isWindowsPrivilegedSID(currentUser.Uid) {
|
||||
log.Debugf("Windows user switching supported: running as privileged SID %s", currentUser.Uid)
|
||||
return true
|
||||
}
|
||||
|
||||
if isPrivilegedUsername(currentUser.Username) {
|
||||
log.Debugf("Windows user switching supported: running as privileged username %s", currentUser.Username)
|
||||
return true
|
||||
}
|
||||
|
||||
log.Debugf("Windows user switching not supported: not running as privileged user (current: %s)", currentUser.Uid)
|
||||
return false
|
||||
}
|
||||
8
client/ssh/server/user_utils_js.go
Normal file
8
client/ssh/server/user_utils_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
// validateUsername is not supported on JS/WASM
|
||||
func validateUsername(_ string) error {
|
||||
return errNotSupported
|
||||
}
|
||||
908
client/ssh/server/user_utils_test.go
Normal file
908
client/ssh/server/user_utils_test.go
Normal file
@@ -0,0 +1,908 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// Test helper functions
|
||||
func createTestUser(username, uid, gid, homeDir string) *user.User {
|
||||
return &user.User{
|
||||
Uid: uid,
|
||||
Gid: gid,
|
||||
Username: username,
|
||||
Name: username,
|
||||
HomeDir: homeDir,
|
||||
}
|
||||
}
|
||||
|
||||
// Test dependency injection setup - injects platform dependencies to test real logic
|
||||
func setupTestDependencies(currentUser *user.User, currentUserErr error, os string, euid int, lookupUsers map[string]*user.User, lookupErrors map[string]error) func() {
|
||||
// Store originals
|
||||
originalGetCurrentUser := getCurrentUser
|
||||
originalLookupUser := lookupUser
|
||||
originalGetCurrentOS := getCurrentOS
|
||||
originalGetEuid := getEuid
|
||||
|
||||
// Reset caches to ensure clean test state
|
||||
|
||||
// Set test values - inject platform dependencies
|
||||
getCurrentUser = func() (*user.User, error) {
|
||||
return currentUser, currentUserErr
|
||||
}
|
||||
|
||||
lookupUser = func(username string) (*user.User, error) {
|
||||
if err, exists := lookupErrors[username]; exists {
|
||||
return nil, err
|
||||
}
|
||||
if userObj, exists := lookupUsers[username]; exists {
|
||||
return userObj, nil
|
||||
}
|
||||
return nil, errors.New("user: unknown user " + username)
|
||||
}
|
||||
|
||||
getCurrentOS = func() string {
|
||||
return os
|
||||
}
|
||||
|
||||
getEuid = func() int {
|
||||
return euid
|
||||
}
|
||||
|
||||
// Mock privilege detection based on the test user
|
||||
getIsProcessPrivileged = func() bool {
|
||||
if currentUser == nil {
|
||||
return false
|
||||
}
|
||||
// Check both username and SID for Windows systems
|
||||
if os == "windows" && isWindowsPrivilegedSID(currentUser.Uid) {
|
||||
return true
|
||||
}
|
||||
return isPrivilegedUsername(currentUser.Username)
|
||||
}
|
||||
|
||||
// Return cleanup function
|
||||
return func() {
|
||||
getCurrentUser = originalGetCurrentUser
|
||||
lookupUser = originalLookupUser
|
||||
getCurrentOS = originalGetCurrentOS
|
||||
getEuid = originalGetEuid
|
||||
|
||||
getIsProcessPrivileged = isCurrentProcessPrivileged
|
||||
|
||||
// Reset caches after test
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPrivileges_ComprehensiveMatrix(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
os string
|
||||
euid int
|
||||
currentUser *user.User
|
||||
requestedUsername string
|
||||
featureSupportsUserSwitch bool
|
||||
allowRoot bool
|
||||
lookupUsers map[string]*user.User
|
||||
expectedAllowed bool
|
||||
expectedRequiresSwitch bool
|
||||
}{
|
||||
{
|
||||
name: "linux_root_can_switch_to_alice",
|
||||
os: "linux",
|
||||
euid: 0, // Root process
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
},
|
||||
expectedAllowed: true,
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_fallback_to_current_user",
|
||||
os: "linux",
|
||||
euid: 1000, // Non-root process
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "bob",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Should fallback to current user (alice)
|
||||
expectedRequiresSwitch: false, // Fallback means no actual switching
|
||||
},
|
||||
{
|
||||
name: "windows_admin_can_switch_to_alice",
|
||||
os: "windows",
|
||||
euid: 1000, // Irrelevant on Windows
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: true,
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_no_fallback_hard_failure",
|
||||
os: "windows",
|
||||
euid: 1000, // Irrelevant on Windows
|
||||
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
|
||||
requestedUsername: "bob",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"bob": createTestUser("bob", "S-1-5-21-123456789-123456789-123456789-1002", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\bob"),
|
||||
},
|
||||
expectedAllowed: true, // Let OS decide - deferred security check
|
||||
expectedRequiresSwitch: true, // Different user was requested
|
||||
},
|
||||
// Comprehensive test matrix: non-root linux with different allowRoot settings
|
||||
{
|
||||
name: "linux_non_root_request_root_allowRoot_false",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Fallback allows access regardless of root setting
|
||||
expectedRequiresSwitch: false, // Fallback case, no switching
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_request_root_allowRoot_true",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Should fallback to alice (non-privileged process)
|
||||
expectedRequiresSwitch: false, // Fallback means no actual switching
|
||||
},
|
||||
// Windows admin test matrix
|
||||
{
|
||||
name: "windows_admin_request_root_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_request_root_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Windows user switching should work like Unix
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
// Windows non-admin test matrix
|
||||
{
|
||||
name: "windows_non_admin_request_root_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed (allowRoot=false takes precedence)
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_system_account_allowRoot_false",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: false, // Root not allowed
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_system_account_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("NETBIRD\\WIN2K19-C2$", "S-1-5-18", "S-1-5-18", "C:\\Windows\\System32"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // SYSTEM can switch to root
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_request_root_allowRoot_true",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "root",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Let OS decide - deferred security check
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Feature doesn't support user switching scenarios
|
||||
{
|
||||
name: "linux_root_feature_no_user_switching_same_user",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "root", // Same user
|
||||
featureSupportsUserSwitch: false,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
},
|
||||
expectedAllowed: true, // Same user should work regardless of feature support
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
{
|
||||
name: "linux_root_feature_no_user_switching_different_user",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: false, // Feature doesn't support switching
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
},
|
||||
expectedAllowed: false, // Should deny because feature doesn't support switching
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Empty username (current user) scenarios
|
||||
{
|
||||
name: "linux_non_root_current_user_empty_username",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
requestedUsername: "", // Empty = current user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Current user should always work
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
{
|
||||
name: "linux_root_current_user_empty_username_root_not_allowed",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "", // Empty = current user (root)
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false, // Root not allowed
|
||||
expectedAllowed: false, // Should deny root even when it's current user
|
||||
expectedRequiresSwitch: false,
|
||||
},
|
||||
|
||||
// User not found scenarios
|
||||
{
|
||||
name: "linux_root_user_not_found",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
requestedUsername: "nonexistent",
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{}, // No users defined = user not found
|
||||
expectedAllowed: false, // Should fail due to user not found
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Windows feature doesn't support user switching
|
||||
{
|
||||
name: "windows_admin_feature_no_user_switching_different_user",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
featureSupportsUserSwitch: false, // Feature doesn't support switching
|
||||
allowRoot: true,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: false, // Should deny because feature doesn't support switching
|
||||
expectedRequiresSwitch: true,
|
||||
},
|
||||
|
||||
// Windows regular user scenarios (non-admin)
|
||||
{
|
||||
name: "windows_regular_user_same_user",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "alice", // Same user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
lookupUsers: map[string]*user.User{
|
||||
"alice": createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
},
|
||||
expectedAllowed: true, // Regular user accessing themselves should work
|
||||
expectedRequiresSwitch: false, // No switching for same user
|
||||
},
|
||||
{
|
||||
name: "windows_regular_user_empty_username",
|
||||
os: "windows",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\Users\\alice"),
|
||||
requestedUsername: "", // Empty = current user
|
||||
featureSupportsUserSwitch: true,
|
||||
allowRoot: false,
|
||||
expectedAllowed: true, // Current user should always work
|
||||
expectedRequiresSwitch: false, // No switching for current user
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Inject platform dependencies to test real logic
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, tt.lookupUsers, nil)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.Equal(t, tt.expectedAllowed, result.Allowed)
|
||||
assert.Equal(t, tt.expectedRequiresSwitch, result.RequiresUserSwitching)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsedFallback_MeansNoPrivilegeDropping(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Fallback mechanism is Unix-specific")
|
||||
}
|
||||
|
||||
// Create test scenario where fallback should occur
|
||||
server := &Server{allowRootLogin: true}
|
||||
|
||||
// Mock dependencies to simulate non-privileged user
|
||||
originalGetCurrentUser := getCurrentUser
|
||||
originalGetIsProcessPrivileged := getIsProcessPrivileged
|
||||
|
||||
defer func() {
|
||||
getCurrentUser = originalGetCurrentUser
|
||||
getIsProcessPrivileged = originalGetIsProcessPrivileged
|
||||
|
||||
}()
|
||||
|
||||
// Set up mocks for fallback scenario
|
||||
getCurrentUser = func() (*user.User, error) {
|
||||
return createTestUser("netbird", "1000", "1000", "/var/lib/netbird"), nil
|
||||
}
|
||||
getIsProcessPrivileged = func() bool { return false } // Non-privileged
|
||||
|
||||
// Request different user - should fallback
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "alice",
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
// Verify fallback occurred
|
||||
assert.True(t, result.Allowed, "Should allow with fallback")
|
||||
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
|
||||
assert.Equal(t, "netbird", result.User.Username, "Should return current user")
|
||||
assert.False(t, result.RequiresUserSwitching, "Should not require switching when fallback is used")
|
||||
|
||||
// Key assertion: When UsedFallback is true, no privilege dropping should be needed
|
||||
// because all privilege checks have already been performed and we're using current user
|
||||
t.Logf("UsedFallback=true means: current user (%s) is the target, no privilege dropping needed",
|
||||
result.User.Username)
|
||||
}
|
||||
|
||||
func TestPrivilegedUsernameDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
platform string
|
||||
privileged bool
|
||||
}{
|
||||
// Unix/Linux tests
|
||||
{"unix_root", "root", "linux", true},
|
||||
{"unix_regular_user", "alice", "linux", false},
|
||||
{"unix_root_capital", "Root", "linux", false}, // Case-sensitive
|
||||
|
||||
// Windows tests
|
||||
{"windows_administrator", "Administrator", "windows", true},
|
||||
{"windows_system", "SYSTEM", "windows", true},
|
||||
{"windows_admin", "admin", "windows", true},
|
||||
{"windows_admin_lowercase", "administrator", "windows", true}, // Case-insensitive
|
||||
{"windows_domain_admin", "DOMAIN\\Administrator", "windows", true},
|
||||
{"windows_email_admin", "admin@domain.com", "windows", true},
|
||||
{"windows_regular_user", "alice", "windows", false},
|
||||
{"windows_domain_user", "DOMAIN\\alice", "windows", false},
|
||||
{"windows_localsystem", "localsystem", "windows", true},
|
||||
{"windows_networkservice", "networkservice", "windows", true},
|
||||
{"windows_localservice", "localservice", "windows", true},
|
||||
|
||||
// Computer accounts (these depend on current user context in real implementation)
|
||||
{"windows_computer_account", "WIN2K19-C2$", "windows", false}, // Computer account by itself not privileged
|
||||
{"windows_domain_computer", "DOMAIN\\COMPUTER$", "windows", false}, // Domain computer account
|
||||
|
||||
// Cross-platform
|
||||
{"root_on_windows", "root", "windows", true}, // Root should be privileged everywhere
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock the platform for this test
|
||||
cleanup := setupTestDependencies(nil, nil, tt.platform, 1000, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
result := isPrivilegedUsername(tt.username)
|
||||
assert.Equal(t, tt.privileged, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestWindowsPrivilegedSIDDetection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
sid string
|
||||
privileged bool
|
||||
description string
|
||||
}{
|
||||
// Well-known system accounts
|
||||
{"system_account", "S-1-5-18", true, "Local System (SYSTEM)"},
|
||||
{"local_service", "S-1-5-19", true, "Local Service"},
|
||||
{"network_service", "S-1-5-20", true, "Network Service"},
|
||||
{"administrators_group", "S-1-5-32-544", true, "Administrators group"},
|
||||
{"builtin_administrator", "S-1-5-500", true, "Built-in Administrator"},
|
||||
|
||||
// Domain accounts
|
||||
{"domain_administrator", "S-1-5-21-1234567890-1234567890-1234567890-500", true, "Domain Administrator (RID 500)"},
|
||||
{"domain_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-512", true, "Domain Admins group"},
|
||||
{"domain_controllers_group", "S-1-5-21-1234567890-1234567890-1234567890-516", true, "Domain Controllers group"},
|
||||
{"enterprise_admins_group", "S-1-5-21-1234567890-1234567890-1234567890-519", true, "Enterprise Admins group"},
|
||||
|
||||
// Regular users
|
||||
{"regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1001", false, "Regular domain user"},
|
||||
{"another_regular_user", "S-1-5-21-1234567890-1234567890-1234567890-1234", false, "Another regular user"},
|
||||
{"local_user", "S-1-5-21-1234567890-1234567890-1234567890-1000", false, "Local regular user"},
|
||||
|
||||
// Groups that are not privileged
|
||||
{"domain_users", "S-1-5-21-1234567890-1234567890-1234567890-513", false, "Domain Users group"},
|
||||
{"power_users", "S-1-5-32-547", false, "Power Users group"},
|
||||
|
||||
// Invalid SIDs
|
||||
{"malformed_sid", "S-1-5-invalid", false, "Malformed SID"},
|
||||
{"empty_sid", "", false, "Empty SID"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := isWindowsPrivilegedSID(tt.sid)
|
||||
assert.Equal(t, tt.privileged, result, "Failed for %s: %s", tt.description, tt.sid)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestIsSameUser(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
user1 string
|
||||
user2 string
|
||||
os string
|
||||
expected bool
|
||||
}{
|
||||
// Basic cases
|
||||
{"same_username", "alice", "alice", "linux", true},
|
||||
{"different_username", "alice", "bob", "linux", false},
|
||||
|
||||
// Linux (no domain processing)
|
||||
{"linux_domain_vs_bare", "DOMAIN\\alice", "alice", "linux", false},
|
||||
{"linux_email_vs_bare", "alice@domain.com", "alice", "linux", false},
|
||||
{"linux_same_literal", "DOMAIN\\alice", "DOMAIN\\alice", "linux", true},
|
||||
|
||||
// Windows (with domain processing) - Note: parameter order is (requested, current, os, expected)
|
||||
{"windows_domain_vs_bare", "alice", "DOMAIN\\alice", "windows", true}, // bare username matches domain current user
|
||||
{"windows_email_vs_bare", "alice", "alice@domain.com", "windows", true}, // bare username matches email current user
|
||||
{"windows_different_domains_same_user", "DOMAIN1\\alice", "DOMAIN2\\alice", "windows", false}, // SECURITY: different domains = different users
|
||||
{"windows_case_insensitive", "Alice", "alice", "windows", true},
|
||||
{"windows_different_users", "DOMAIN\\alice", "DOMAIN\\bob", "windows", false},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Set up OS mock
|
||||
cleanup := setupTestDependencies(nil, nil, tt.os, 1000, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
result := isSameUser(tt.user1, tt.user2)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsernameValidation_Unix(t *testing.T) {
|
||||
if runtime.GOOS == "windows" {
|
||||
t.Skip("Unix-specific username validation tests")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid usernames (Unix/POSIX)
|
||||
{"valid_alphanumeric", "user123", false, ""},
|
||||
{"valid_with_dots", "user.name", false, ""},
|
||||
{"valid_with_hyphens", "user-name", false, ""},
|
||||
{"valid_with_underscores", "user_name", false, ""},
|
||||
{"valid_uppercase", "UserName", false, ""},
|
||||
{"valid_starting_with_digit", "123user", false, ""},
|
||||
{"valid_starting_with_dot", ".hidden", false, ""},
|
||||
|
||||
// Invalid usernames (Unix/POSIX)
|
||||
{"empty_username", "", true, "username cannot be empty"},
|
||||
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
|
||||
{"username_starting_with_hyphen", "-user", true, "invalid characters"}, // POSIX restriction
|
||||
{"username_with_spaces", "user name", true, "invalid characters"},
|
||||
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
|
||||
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
|
||||
{"username_with_pipe", "user|rm", true, "invalid characters"},
|
||||
{"username_with_ampersand", "user&rm", true, "invalid characters"},
|
||||
{"username_with_quotes", "user\"name", true, "invalid characters"},
|
||||
{"username_with_newline", "user\nname", true, "invalid characters"},
|
||||
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
|
||||
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
|
||||
{"username_with_at_symbol", "user@domain", true, "invalid characters"}, // Not allowed in bare Unix usernames
|
||||
{"username_with_backslash", "user\\name", true, "invalid characters"}, // Not allowed in Unix usernames
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.username)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Should reject invalid username")
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "Should accept valid username")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUsernameValidation_Windows(t *testing.T) {
|
||||
if runtime.GOOS != "windows" {
|
||||
t.Skip("Windows-specific username validation tests")
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
username string
|
||||
wantErr bool
|
||||
errMsg string
|
||||
}{
|
||||
// Valid usernames (Windows)
|
||||
{"valid_alphanumeric", "user123", false, ""},
|
||||
{"valid_with_dots", "user.name", false, ""},
|
||||
{"valid_with_hyphens", "user-name", false, ""},
|
||||
{"valid_with_underscores", "user_name", false, ""},
|
||||
{"valid_uppercase", "UserName", false, ""},
|
||||
{"valid_starting_with_digit", "123user", false, ""},
|
||||
{"valid_starting_with_dot", ".hidden", false, ""},
|
||||
{"valid_starting_with_hyphen", "-user", false, ""}, // Windows allows this
|
||||
{"valid_domain_username", "DOMAIN\\user", false, ""}, // Windows domain format
|
||||
{"valid_email_username", "user@domain.com", false, ""}, // Windows email format
|
||||
{"valid_machine_username", "MACHINE\\user", false, ""}, // Windows machine format
|
||||
|
||||
// Invalid usernames (Windows)
|
||||
{"empty_username", "", true, "username cannot be empty"},
|
||||
{"username_too_long", "thisusernameiswaytoolongandexceedsthe32characterlimit", true, "username too long"},
|
||||
{"username_with_spaces", "user name", true, "invalid characters"},
|
||||
{"username_with_shell_metacharacters", "user;rm", true, "invalid characters"},
|
||||
{"username_with_command_injection", "user`rm -rf /`", true, "invalid characters"},
|
||||
{"username_with_pipe", "user|rm", true, "invalid characters"},
|
||||
{"username_with_ampersand", "user&rm", true, "invalid characters"},
|
||||
{"username_with_quotes", "user\"name", true, "invalid characters"},
|
||||
{"username_with_newline", "user\nname", true, "invalid characters"},
|
||||
{"username_with_brackets", "user[name]", true, "invalid characters"},
|
||||
{"username_with_colon", "user:name", true, "invalid characters"},
|
||||
{"username_with_semicolon", "user;name", true, "invalid characters"},
|
||||
{"username_with_equals", "user=name", true, "invalid characters"},
|
||||
{"username_with_comma", "user,name", true, "invalid characters"},
|
||||
{"username_with_plus", "user+name", true, "invalid characters"},
|
||||
{"username_with_asterisk", "user*name", true, "invalid characters"},
|
||||
{"username_with_question", "user?name", true, "invalid characters"},
|
||||
{"username_with_angles", "user<name>", true, "invalid characters"},
|
||||
{"reserved_dot", ".", true, "cannot be '.' or '..'"},
|
||||
{"reserved_dotdot", "..", true, "cannot be '.' or '..'"},
|
||||
{"username_ending_with_period", "user.", true, "cannot end with a period"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := validateUsername(tt.username)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err, "Should reject invalid username")
|
||||
if tt.errMsg != "" {
|
||||
assert.Contains(t, err.Error(), tt.errMsg, "Error message should contain expected text")
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err, "Should accept valid username")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test real-world integration scenarios with actual platform capabilities
|
||||
func TestCheckPrivileges_RealWorldScenarios(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
feature string
|
||||
featureSupportsUserSwitch bool
|
||||
requestedUsername string
|
||||
allowRoot bool
|
||||
expectedBehaviorPattern string
|
||||
}{
|
||||
{"SSH_login_current_user", "SSH login", true, "", true, "should_allow_current_user"},
|
||||
{"SFTP_current_user", "SFTP", true, "", true, "should_allow_current_user"},
|
||||
{"port_forwarding_current_user", "port forwarding", false, "", true, "should_allow_current_user"},
|
||||
{"SSH_login_root_not_allowed", "SSH login", true, "root", false, "should_deny_root"},
|
||||
{"port_forwarding_different_user", "port forwarding", false, "differentuser", true, "should_deny_switching"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Mock privileged environment to ensure consistent test behavior across environments
|
||||
cleanup := setupTestDependencies(
|
||||
createTestUser("root", "0", "0", "/root"), // Running as root
|
||||
nil,
|
||||
runtime.GOOS,
|
||||
0, // euid 0 (root)
|
||||
map[string]*user.User{
|
||||
"root": createTestUser("root", "0", "0", "/root"),
|
||||
"differentuser": createTestUser("differentuser", "1000", "1000", "/home/differentuser"),
|
||||
},
|
||||
nil,
|
||||
)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: tt.featureSupportsUserSwitch,
|
||||
FeatureName: tt.feature,
|
||||
})
|
||||
|
||||
switch tt.expectedBehaviorPattern {
|
||||
case "should_allow_current_user":
|
||||
assert.True(t, result.Allowed, "Should allow current user access")
|
||||
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
|
||||
case "should_deny_root":
|
||||
assert.False(t, result.Allowed, "Should deny root when not allowed")
|
||||
assert.Contains(t, result.Error.Error(), "root", "Should mention root in error")
|
||||
case "should_deny_switching":
|
||||
assert.False(t, result.Allowed, "Should deny when feature doesn't support switching")
|
||||
assert.Contains(t, result.Error.Error(), "user switching not supported", "Should mention switching in error")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// Test with actual platform capabilities - no mocking
|
||||
func TestCheckPrivileges_ActualPlatform(t *testing.T) {
|
||||
// This test uses the REAL platform capabilities
|
||||
server := &Server{allowRootLogin: true}
|
||||
|
||||
// Test current user access - should always work
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "", // Current user
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.True(t, result.Allowed, "Current user should always be allowed")
|
||||
assert.False(t, result.RequiresUserSwitching, "Current user should not require switching")
|
||||
assert.NotNil(t, result.User, "Should return current user")
|
||||
|
||||
// Test user switching capability based on actual platform
|
||||
actualIsPrivileged := isCurrentProcessPrivileged() // REAL check
|
||||
actualOS := runtime.GOOS // REAL check
|
||||
|
||||
t.Logf("Platform capabilities: OS=%s, isPrivileged=%v, supportsUserSwitching=%v",
|
||||
actualOS, actualIsPrivileged, actualIsPrivileged)
|
||||
|
||||
// Test requesting different user
|
||||
result = server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: "nonexistentuser",
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
switch {
|
||||
case actualOS == "windows":
|
||||
// Windows supports user switching but should fail on nonexistent user
|
||||
assert.False(t, result.Allowed, "Windows should deny nonexistent user")
|
||||
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
|
||||
assert.Contains(t, result.Error.Error(), "not found",
|
||||
"Should indicate user not found")
|
||||
case !actualIsPrivileged:
|
||||
// Non-privileged Unix processes should fallback to current user
|
||||
assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user")
|
||||
assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens")
|
||||
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
|
||||
assert.NotNil(t, result.User, "Should return current user")
|
||||
default:
|
||||
// Privileged Unix processes should attempt user lookup
|
||||
assert.False(t, result.Allowed, "Should fail due to nonexistent user")
|
||||
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
|
||||
assert.Contains(t, result.Error.Error(), "nonexistentuser",
|
||||
"Should indicate user not found")
|
||||
}
|
||||
}
|
||||
|
||||
// Test platform detection logic with dependency injection
|
||||
func TestPlatformLogic_DependencyInjection(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
os string
|
||||
euid int
|
||||
currentUser *user.User
|
||||
expectedIsProcessPrivileged bool
|
||||
expectedSupportsUserSwitching bool
|
||||
}{
|
||||
{
|
||||
name: "linux_root_process",
|
||||
os: "linux",
|
||||
euid: 0,
|
||||
currentUser: createTestUser("root", "0", "0", "/root"),
|
||||
expectedIsProcessPrivileged: true,
|
||||
expectedSupportsUserSwitching: true,
|
||||
},
|
||||
{
|
||||
name: "linux_non_root_process",
|
||||
os: "linux",
|
||||
euid: 1000,
|
||||
currentUser: createTestUser("alice", "1000", "1000", "/home/alice"),
|
||||
expectedIsProcessPrivileged: false,
|
||||
expectedSupportsUserSwitching: false,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_process",
|
||||
os: "windows",
|
||||
euid: 1000, // euid ignored on Windows
|
||||
currentUser: createTestUser("Administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\Users\\Administrator"),
|
||||
expectedIsProcessPrivileged: true,
|
||||
expectedSupportsUserSwitching: true, // Windows supports user switching when privileged
|
||||
},
|
||||
{
|
||||
name: "windows_regular_process",
|
||||
os: "windows",
|
||||
euid: 1000, // euid ignored on Windows
|
||||
currentUser: createTestUser("alice", "1001", "1001", "C:\\Users\\alice"),
|
||||
expectedIsProcessPrivileged: false,
|
||||
expectedSupportsUserSwitching: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Inject platform dependencies and test REAL logic
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, tt.os, tt.euid, nil, nil)
|
||||
defer cleanup()
|
||||
|
||||
// Test the actual functions with injected dependencies
|
||||
actualIsPrivileged := isCurrentProcessPrivileged()
|
||||
actualSupportsUserSwitching := actualIsPrivileged
|
||||
|
||||
assert.Equal(t, tt.expectedIsProcessPrivileged, actualIsPrivileged,
|
||||
"isCurrentProcessPrivileged() result mismatch")
|
||||
assert.Equal(t, tt.expectedSupportsUserSwitching, actualSupportsUserSwitching,
|
||||
"supportsUserSwitching() result mismatch")
|
||||
|
||||
t.Logf("Platform: %s, EUID: %d, User: %s", tt.os, tt.euid, tt.currentUser.Username)
|
||||
t.Logf("Results: isPrivileged=%v, supportsUserSwitching=%v",
|
||||
actualIsPrivileged, actualSupportsUserSwitching)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCheckPrivileges_WindowsElevatedUserSwitching(t *testing.T) {
|
||||
// Test Windows elevated user switching scenarios with simplified privilege logic
|
||||
tests := []struct {
|
||||
name string
|
||||
currentUser *user.User
|
||||
requestedUsername string
|
||||
allowRoot bool
|
||||
expectedAllowed bool
|
||||
expectedErrorContains string
|
||||
}{
|
||||
{
|
||||
name: "windows_admin_can_switch_to_alice",
|
||||
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
|
||||
requestedUsername: "alice",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "windows_non_admin_can_try_switch",
|
||||
currentUser: createTestUser("alice", "S-1-5-21-123456789-123456789-123456789-1001", "S-1-5-21-123456789-123456789-123456789-513", "C:\\\\Users\\\\alice"),
|
||||
requestedUsername: "bob",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true, // Privilege check allows it, OS will reject during execution
|
||||
},
|
||||
{
|
||||
name: "windows_system_can_switch_to_alice",
|
||||
currentUser: createTestUser("SYSTEM", "S-1-5-18", "S-1-5-18", "C:\\\\Windows\\\\system32\\\\config\\\\systemprofile"),
|
||||
requestedUsername: "alice",
|
||||
allowRoot: true,
|
||||
expectedAllowed: true,
|
||||
},
|
||||
{
|
||||
name: "windows_admin_root_not_allowed",
|
||||
currentUser: createTestUser("administrator", "S-1-5-21-123456789-123456789-123456789-500", "S-1-5-32-544", "C:\\\\Users\\\\Administrator"),
|
||||
requestedUsername: "root",
|
||||
allowRoot: false,
|
||||
expectedAllowed: false,
|
||||
expectedErrorContains: "privileged user login is disabled",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Setup test dependencies with Windows OS and specified privileges
|
||||
lookupUsers := map[string]*user.User{
|
||||
tt.requestedUsername: createTestUser(tt.requestedUsername, "1002", "1002", "C:\\\\Users\\\\"+tt.requestedUsername),
|
||||
}
|
||||
cleanup := setupTestDependencies(tt.currentUser, nil, "windows", 1000, lookupUsers, nil)
|
||||
defer cleanup()
|
||||
|
||||
server := &Server{allowRootLogin: tt.allowRoot}
|
||||
|
||||
result := server.CheckPrivileges(PrivilegeCheckRequest{
|
||||
RequestedUsername: tt.requestedUsername,
|
||||
FeatureSupportsUserSwitch: true,
|
||||
FeatureName: "SSH login",
|
||||
})
|
||||
|
||||
assert.Equal(t, tt.expectedAllowed, result.Allowed,
|
||||
"Privilege check result should match expected for %s", tt.name)
|
||||
|
||||
if !tt.expectedAllowed && tt.expectedErrorContains != "" {
|
||||
assert.NotNil(t, result.Error, "Should have error when not allowed")
|
||||
assert.Contains(t, result.Error.Error(), tt.expectedErrorContains,
|
||||
"Error should contain expected message")
|
||||
}
|
||||
|
||||
if tt.expectedAllowed && tt.requestedUsername != "" && tt.currentUser.Username != tt.requestedUsername {
|
||||
assert.True(t, result.RequiresUserSwitching, "Should require user switching for different user")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
8
client/ssh/server/userswitching_js.go
Normal file
8
client/ssh/server/userswitching_js.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build js
|
||||
|
||||
package server
|
||||
|
||||
// enableUserSwitching is not supported on JS/WASM
|
||||
func enableUserSwitching() error {
|
||||
return errNotSupported
|
||||
}
|
||||
233
client/ssh/server/userswitching_unix.go
Normal file
233
client/ssh/server/userswitching_unix.go
Normal file
@@ -0,0 +1,233 @@
|
||||
//go:build unix
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strconv"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
// POSIX portable filename character set regex: [a-zA-Z0-9._-]
|
||||
// First character cannot be hyphen (POSIX requirement)
|
||||
var posixUsernameRegex = regexp.MustCompile(`^[a-zA-Z0-9._][a-zA-Z0-9._-]*$`)
|
||||
|
||||
// validateUsername validates that a username conforms to POSIX standards with security considerations
|
||||
func validateUsername(username string) error {
|
||||
if username == "" {
|
||||
return errors.New("username cannot be empty")
|
||||
}
|
||||
|
||||
// POSIX allows up to 256 characters, but practical limit is 32 for compatibility
|
||||
if len(username) > 32 {
|
||||
return errors.New("username too long (max 32 characters)")
|
||||
}
|
||||
|
||||
if !posixUsernameRegex.MatchString(username) {
|
||||
return errors.New("username contains invalid characters (must match POSIX portable filename character set)")
|
||||
}
|
||||
|
||||
if username == "." || username == ".." {
|
||||
return fmt.Errorf("username cannot be '.' or '..'")
|
||||
}
|
||||
|
||||
// Warn if username is fully numeric (can cause issues with UID/username ambiguity)
|
||||
if isFullyNumeric(username) {
|
||||
log.Warnf("fully numeric username '%s' may cause issues with some commands", username)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// isFullyNumeric checks if username contains only digits
|
||||
func isFullyNumeric(username string) bool {
|
||||
for _, char := range username {
|
||||
if char < '0' || char > '9' {
|
||||
return false
|
||||
}
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// createPtyLoginCommand creates a Pty command using login for privileged processes
|
||||
func (s *Server) createPtyLoginCommand(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
||||
loginPath, args, err := s.getLoginCmd(localUser.Username, session.RemoteAddr())
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get login command: %w", err)
|
||||
}
|
||||
|
||||
execCmd := exec.CommandContext(session.Context(), loginPath, args...)
|
||||
execCmd.Dir = localUser.HomeDir
|
||||
execCmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
return execCmd, nil
|
||||
}
|
||||
|
||||
// getLoginCmd returns the login command and args for privileged Pty user switching
|
||||
func (s *Server) getLoginCmd(username string, remoteAddr net.Addr) (string, []string, error) {
|
||||
loginPath, err := exec.LookPath("login")
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("login command not available: %w", err)
|
||||
}
|
||||
|
||||
addrPort, err := netip.ParseAddrPort(remoteAddr.String())
|
||||
if err != nil {
|
||||
return "", nil, fmt.Errorf("parse remote address: %w", err)
|
||||
}
|
||||
|
||||
switch runtime.GOOS {
|
||||
case "linux":
|
||||
// Special handling for Arch Linux without /etc/pam.d/remote
|
||||
if s.fileExists("/etc/arch-release") && !s.fileExists("/etc/pam.d/remote") {
|
||||
return loginPath, []string{"-f", username, "-p"}, nil
|
||||
}
|
||||
return loginPath, []string{"-f", username, "-h", addrPort.Addr().String(), "-p"}, nil
|
||||
case "darwin", "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||
return loginPath, []string{"-fp", "-h", addrPort.Addr().String(), username}, nil
|
||||
default:
|
||||
return "", nil, fmt.Errorf("unsupported Unix platform for login command: %s", runtime.GOOS)
|
||||
}
|
||||
}
|
||||
|
||||
// fileExists checks if a file exists (helper for login command logic)
|
||||
func (s *Server) fileExists(path string) bool {
|
||||
_, err := os.Stat(path)
|
||||
return err == nil
|
||||
}
|
||||
|
||||
// parseUserCredentials extracts numeric UID, GID, and supplementary groups
|
||||
func (s *Server) parseUserCredentials(localUser *user.User) (uint32, uint32, []uint32, error) {
|
||||
uid64, err := strconv.ParseUint(localUser.Uid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, nil, fmt.Errorf("invalid UID %s: %w", localUser.Uid, err)
|
||||
}
|
||||
uid := uint32(uid64)
|
||||
|
||||
gid64, err := strconv.ParseUint(localUser.Gid, 10, 32)
|
||||
if err != nil {
|
||||
return 0, 0, nil, fmt.Errorf("invalid GID %s: %w", localUser.Gid, err)
|
||||
}
|
||||
gid := uint32(gid64)
|
||||
|
||||
groups, err := s.getSupplementaryGroups(localUser.Username)
|
||||
if err != nil {
|
||||
log.Warnf("failed to get supplementary groups for user %s: %v", localUser.Username, err)
|
||||
groups = []uint32{gid}
|
||||
}
|
||||
|
||||
return uid, gid, groups, nil
|
||||
}
|
||||
|
||||
// getSupplementaryGroups retrieves supplementary group IDs for a user
|
||||
func (s *Server) getSupplementaryGroups(username string) ([]uint32, error) {
|
||||
u, err := user.Lookup(username)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("lookup user %s: %w", username, err)
|
||||
}
|
||||
|
||||
groupIDStrings, err := u.GroupIds()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get group IDs for user %s: %w", username, err)
|
||||
}
|
||||
|
||||
groups := make([]uint32, len(groupIDStrings))
|
||||
for i, gidStr := range groupIDStrings {
|
||||
gid64, err := strconv.ParseUint(gidStr, 10, 32)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("invalid group ID %s for user %s: %w", gidStr, username, err)
|
||||
}
|
||||
groups[i] = uint32(gid64)
|
||||
}
|
||||
|
||||
return groups, nil
|
||||
}
|
||||
|
||||
// createExecutorCommand creates a command that spawns netbird ssh exec for privilege dropping.
|
||||
// Returns the command and a cleanup function (no-op on Unix).
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
if err := validateUsername(localUser.Username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", localUser.Username, err)
|
||||
}
|
||||
|
||||
uid, gid, groups, err := s.parseUserCredentials(localUser)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse user credentials: %w", err)
|
||||
}
|
||||
privilegeDropper := NewPrivilegeDropper()
|
||||
config := ExecutorConfig{
|
||||
UID: uid,
|
||||
GID: gid,
|
||||
Groups: groups,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: getUserShell(localUser.Uid),
|
||||
Command: session.RawCommand(),
|
||||
PTY: hasPty,
|
||||
}
|
||||
|
||||
cmd, err := privilegeDropper.CreateExecutorCommand(session.Context(), config)
|
||||
return cmd, func() {}, err
|
||||
}
|
||||
|
||||
// enableUserSwitching is a no-op on Unix systems
|
||||
func enableUserSwitching() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createPtyCommand creates the exec.Cmd for Pty execution respecting privilege check results
|
||||
func (s *Server) createPtyCommand(privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
|
||||
localUser := privilegeResult.User
|
||||
if localUser == nil {
|
||||
return nil, errors.New("no user in privilege result")
|
||||
}
|
||||
|
||||
if privilegeResult.UsedFallback {
|
||||
return s.createDirectPtyCommand(session, localUser, ptyReq), nil
|
||||
}
|
||||
|
||||
return s.createPtyLoginCommand(localUser, ptyReq, session)
|
||||
}
|
||||
|
||||
// createDirectPtyCommand creates a direct Pty command without privilege dropping
|
||||
func (s *Server) createDirectPtyCommand(session ssh.Session, localUser *user.User, ptyReq ssh.Pty) *exec.Cmd {
|
||||
log.Debugf("creating direct Pty command for user %s (no user switching needed)", localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
args := s.getShellCommandArgs(shell, session.RawCommand())
|
||||
|
||||
cmd := exec.CommandContext(session.Context(), args[0], args[1:]...)
|
||||
cmd.Dir = localUser.HomeDir
|
||||
cmd.Env = s.preparePtyEnv(localUser, ptyReq, session)
|
||||
|
||||
return cmd
|
||||
}
|
||||
|
||||
// preparePtyEnv prepares environment variables for Pty execution
|
||||
func (s *Server) preparePtyEnv(localUser *user.User, ptyReq ssh.Pty, session ssh.Session) []string {
|
||||
termType := ptyReq.Term
|
||||
if termType == "" {
|
||||
termType = "xterm-256color"
|
||||
}
|
||||
|
||||
env := prepareUserEnv(localUser, getUserShell(localUser.Uid))
|
||||
env = append(env, prepareSSHEnv(session)...)
|
||||
env = append(env, fmt.Sprintf("TERM=%s", termType))
|
||||
|
||||
for _, v := range session.Environ() {
|
||||
if acceptEnv(v) {
|
||||
env = append(env, v)
|
||||
}
|
||||
}
|
||||
return env
|
||||
}
|
||||
274
client/ssh/server/userswitching_windows.go
Normal file
274
client/ssh/server/userswitching_windows.go
Normal file
@@ -0,0 +1,274 @@
|
||||
//go:build windows
|
||||
|
||||
package server
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"strings"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
// validateUsername validates Windows usernames according to SAM Account Name rules
|
||||
func validateUsername(username string) error {
|
||||
if username == "" {
|
||||
return fmt.Errorf("username cannot be empty")
|
||||
}
|
||||
|
||||
usernameToValidate := extractUsernameFromDomain(username)
|
||||
|
||||
if err := validateUsernameLength(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateUsernameCharacters(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := validateUsernameFormat(usernameToValidate); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// extractUsernameFromDomain extracts the username part from domain\username or username@domain format
|
||||
func extractUsernameFromDomain(username string) string {
|
||||
if idx := strings.LastIndex(username, `\`); idx != -1 {
|
||||
return username[idx+1:]
|
||||
}
|
||||
if idx := strings.Index(username, "@"); idx != -1 {
|
||||
return username[:idx]
|
||||
}
|
||||
return username
|
||||
}
|
||||
|
||||
// validateUsernameLength checks if username length is within Windows limits
|
||||
func validateUsernameLength(username string) error {
|
||||
if len(username) > 20 {
|
||||
return fmt.Errorf("username too long (max 20 characters for Windows)")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUsernameCharacters checks for invalid characters in Windows usernames
|
||||
func validateUsernameCharacters(username string) error {
|
||||
invalidChars := []rune{'"', '/', '[', ']', ':', ';', '|', '=', ',', '+', '*', '?', '<', '>', ' ', '`', '&', '\n'}
|
||||
for _, char := range username {
|
||||
for _, invalid := range invalidChars {
|
||||
if char == invalid {
|
||||
return fmt.Errorf("username contains invalid characters")
|
||||
}
|
||||
}
|
||||
if char < 32 || char == 127 {
|
||||
return fmt.Errorf("username contains control characters")
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// validateUsernameFormat checks for invalid username formats and patterns
|
||||
func validateUsernameFormat(username string) error {
|
||||
if username == "." || username == ".." {
|
||||
return fmt.Errorf("username cannot be '.' or '..'")
|
||||
}
|
||||
|
||||
if strings.HasSuffix(username, ".") {
|
||||
return fmt.Errorf("username cannot end with a period")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// createExecutorCommand creates a command using Windows executor for privilege dropping.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createExecutorCommand(session ssh.Session, localUser *user.User, hasPty bool) (*exec.Cmd, func(), error) {
|
||||
log.Debugf("creating Windows executor command for user %s (Pty: %v)", localUser.Username, hasPty)
|
||||
|
||||
username, _ := s.parseUsername(localUser.Username)
|
||||
if err := validateUsername(username); err != nil {
|
||||
return nil, nil, fmt.Errorf("invalid username %q: %w", username, err)
|
||||
}
|
||||
|
||||
return s.createUserSwitchCommand(localUser, session, hasPty)
|
||||
}
|
||||
|
||||
// createUserSwitchCommand creates a command with Windows user switching.
|
||||
// Returns the command and a cleanup function that must be called after starting the process.
|
||||
func (s *Server) createUserSwitchCommand(localUser *user.User, session ssh.Session, interactive bool) (*exec.Cmd, func(), error) {
|
||||
username, domain := s.parseUsername(localUser.Username)
|
||||
|
||||
shell := getUserShell(localUser.Uid)
|
||||
|
||||
rawCmd := session.RawCommand()
|
||||
var command string
|
||||
if rawCmd != "" {
|
||||
command = rawCmd
|
||||
}
|
||||
|
||||
config := WindowsExecutorConfig{
|
||||
Username: username,
|
||||
Domain: domain,
|
||||
WorkingDir: localUser.HomeDir,
|
||||
Shell: shell,
|
||||
Command: command,
|
||||
Interactive: interactive || (rawCmd == ""),
|
||||
}
|
||||
|
||||
dropper := NewPrivilegeDropper()
|
||||
cmd, token, err := dropper.CreateWindowsExecutorCommand(session.Context(), config)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
|
||||
cleanup := func() {
|
||||
if token != 0 {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return cmd, cleanup, nil
|
||||
}
|
||||
|
||||
// parseUsername extracts username and domain from a Windows username
|
||||
func (s *Server) parseUsername(fullUsername string) (username, domain string) {
|
||||
// Handle DOMAIN\username format
|
||||
if idx := strings.LastIndex(fullUsername, `\`); idx != -1 {
|
||||
domain = fullUsername[:idx]
|
||||
username = fullUsername[idx+1:]
|
||||
return username, domain
|
||||
}
|
||||
|
||||
// Handle username@domain format
|
||||
if username, domain, ok := strings.Cut(fullUsername, "@"); ok {
|
||||
return username, domain
|
||||
}
|
||||
|
||||
// Local user (no domain)
|
||||
return fullUsername, "."
|
||||
}
|
||||
|
||||
// hasPrivilege checks if the current process has a specific privilege
|
||||
func hasPrivilege(token windows.Handle, privilegeName string) (bool, error) {
|
||||
var luid windows.LUID
|
||||
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
|
||||
return false, fmt.Errorf("lookup privilege value: %w", err)
|
||||
}
|
||||
|
||||
var returnLength uint32
|
||||
err := windows.GetTokenInformation(
|
||||
windows.Token(token),
|
||||
windows.TokenPrivileges,
|
||||
nil, // null buffer to get size
|
||||
0,
|
||||
&returnLength,
|
||||
)
|
||||
|
||||
if err != nil && !errors.Is(err, windows.ERROR_INSUFFICIENT_BUFFER) {
|
||||
return false, fmt.Errorf("get token information size: %w", err)
|
||||
}
|
||||
|
||||
buffer := make([]byte, returnLength)
|
||||
err = windows.GetTokenInformation(
|
||||
windows.Token(token),
|
||||
windows.TokenPrivileges,
|
||||
&buffer[0],
|
||||
returnLength,
|
||||
&returnLength,
|
||||
)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("get token information: %w", err)
|
||||
}
|
||||
|
||||
privileges := (*windows.Tokenprivileges)(unsafe.Pointer(&buffer[0]))
|
||||
|
||||
// Check if the privilege is present and enabled
|
||||
for i := uint32(0); i < privileges.PrivilegeCount; i++ {
|
||||
privilege := (*windows.LUIDAndAttributes)(unsafe.Pointer(
|
||||
uintptr(unsafe.Pointer(&privileges.Privileges[0])) +
|
||||
uintptr(i)*unsafe.Sizeof(windows.LUIDAndAttributes{}),
|
||||
))
|
||||
if privilege.Luid == luid {
|
||||
return (privilege.Attributes & windows.SE_PRIVILEGE_ENABLED) != 0, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, nil
|
||||
}
|
||||
|
||||
// enablePrivilege enables a specific privilege for the current process token
|
||||
// This is required because privileges like SeAssignPrimaryTokenPrivilege are present
|
||||
// but disabled by default, even for the SYSTEM account
|
||||
func enablePrivilege(token windows.Handle, privilegeName string) error {
|
||||
var luid windows.LUID
|
||||
if err := windows.LookupPrivilegeValue(nil, windows.StringToUTF16Ptr(privilegeName), &luid); err != nil {
|
||||
return fmt.Errorf("lookup privilege value for %s: %w", privilegeName, err)
|
||||
}
|
||||
|
||||
privileges := windows.Tokenprivileges{
|
||||
PrivilegeCount: 1,
|
||||
Privileges: [1]windows.LUIDAndAttributes{
|
||||
{
|
||||
Luid: luid,
|
||||
Attributes: windows.SE_PRIVILEGE_ENABLED,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
err := windows.AdjustTokenPrivileges(
|
||||
windows.Token(token),
|
||||
false,
|
||||
&privileges,
|
||||
0,
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("adjust token privileges for %s: %w", privilegeName, err)
|
||||
}
|
||||
|
||||
hasPriv, err := hasPrivilege(token, privilegeName)
|
||||
if err != nil {
|
||||
return fmt.Errorf("verify privilege %s after enabling: %w", privilegeName, err)
|
||||
}
|
||||
if !hasPriv {
|
||||
return fmt.Errorf("privilege %s could not be enabled (may not be granted to account)", privilegeName)
|
||||
}
|
||||
|
||||
log.Debugf("Successfully enabled privilege %s for current process", privilegeName)
|
||||
return nil
|
||||
}
|
||||
|
||||
// enableUserSwitching enables required privileges for Windows user switching
|
||||
func enableUserSwitching() error {
|
||||
process := windows.CurrentProcess()
|
||||
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(
|
||||
process,
|
||||
windows.TOKEN_ADJUST_PRIVILEGES|windows.TOKEN_QUERY,
|
||||
&token,
|
||||
)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open process token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
log.Debugf("Failed to close process token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := enablePrivilege(windows.Handle(token), "SeAssignPrimaryTokenPrivilege"); err != nil {
|
||||
return fmt.Errorf("enable SeAssignPrimaryTokenPrivilege: %w", err)
|
||||
}
|
||||
log.Infof("Windows user switching privileges enabled successfully")
|
||||
return nil
|
||||
}
|
||||
487
client/ssh/server/winpty/conpty.go
Normal file
487
client/ssh/server/winpty/conpty.go
Normal file
@@ -0,0 +1,487 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"unsafe"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrEmptyEnvironment = errors.New("empty environment")
|
||||
)
|
||||
|
||||
const (
|
||||
extendedStartupInfoPresent = 0x00080000
|
||||
createUnicodeEnvironment = 0x00000400
|
||||
procThreadAttributePseudoConsole = 0x00020016
|
||||
|
||||
PowerShellCommandFlag = "-Command"
|
||||
|
||||
errCloseInputRead = "close input read handle: %v"
|
||||
errCloseConPtyCleanup = "close ConPty handle during cleanup"
|
||||
)
|
||||
|
||||
// PtyConfig holds configuration for Pty execution.
|
||||
type PtyConfig struct {
|
||||
Shell string
|
||||
Command string
|
||||
Width int
|
||||
Height int
|
||||
WorkingDir string
|
||||
}
|
||||
|
||||
// UserConfig holds user execution configuration.
|
||||
type UserConfig struct {
|
||||
Token windows.Handle
|
||||
Environment []string
|
||||
}
|
||||
|
||||
var (
|
||||
kernel32 = windows.NewLazySystemDLL("kernel32.dll")
|
||||
procClosePseudoConsole = kernel32.NewProc("ClosePseudoConsole")
|
||||
procInitializeProcThreadAttributeList = kernel32.NewProc("InitializeProcThreadAttributeList")
|
||||
procUpdateProcThreadAttribute = kernel32.NewProc("UpdateProcThreadAttribute")
|
||||
procDeleteProcThreadAttributeList = kernel32.NewProc("DeleteProcThreadAttributeList")
|
||||
)
|
||||
|
||||
// ExecutePtyWithUserToken executes a command with ConPty using user token.
|
||||
func ExecutePtyWithUserToken(ctx context.Context, session ssh.Session, ptyConfig PtyConfig, userConfig UserConfig) error {
|
||||
args := buildShellArgs(ptyConfig.Shell, ptyConfig.Command)
|
||||
commandLine := buildCommandLine(args)
|
||||
|
||||
config := ExecutionConfig{
|
||||
Pty: ptyConfig,
|
||||
User: userConfig,
|
||||
Session: session,
|
||||
Context: ctx,
|
||||
}
|
||||
|
||||
return executeConPtyWithConfig(commandLine, config)
|
||||
}
|
||||
|
||||
// ExecutionConfig holds all execution configuration.
|
||||
type ExecutionConfig struct {
|
||||
Pty PtyConfig
|
||||
User UserConfig
|
||||
Session ssh.Session
|
||||
Context context.Context
|
||||
}
|
||||
|
||||
// executeConPtyWithConfig creates ConPty and executes process with configuration.
|
||||
func executeConPtyWithConfig(commandLine string, config ExecutionConfig) error {
|
||||
ctx := config.Context
|
||||
session := config.Session
|
||||
width := config.Pty.Width
|
||||
height := config.Pty.Height
|
||||
userToken := config.User.Token
|
||||
userEnv := config.User.Environment
|
||||
workingDir := config.Pty.WorkingDir
|
||||
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty pipes: %w", err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(width, height, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
return fmt.Errorf("create ConPty: %w", err)
|
||||
}
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(userToken)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("duplicate to primary token: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(primaryToken); err != nil {
|
||||
log.Debugf("close primary token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
siEx, err := setupConPtyStartupInfo(hPty)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("setup startup info: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
_, _, _ = procDeleteProcThreadAttributeList.Call(uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)))
|
||||
}()
|
||||
|
||||
pi, err := createConPtyProcess(commandLine, primaryToken, userEnv, workingDir, siEx)
|
||||
if err != nil {
|
||||
if closeErr, _, _ := procClosePseudoConsole.Call(uintptr(hPty)); closeErr == 0 {
|
||||
log.Debugf(errCloseConPtyCleanup)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
return fmt.Errorf("create process as user with ConPty: %w", err)
|
||||
}
|
||||
defer closeProcessInfo(pi)
|
||||
|
||||
if err := windows.CloseHandle(inputRead); err != nil {
|
||||
log.Debugf(errCloseInputRead, err)
|
||||
}
|
||||
if err := windows.CloseHandle(outputWrite); err != nil {
|
||||
log.Debugf("close output write handle: %v", err)
|
||||
}
|
||||
|
||||
return bridgeConPtyIO(ctx, hPty, inputWrite, outputRead, session, session, session, pi.Process)
|
||||
}
|
||||
|
||||
// createConPtyPipes creates input/output pipes for ConPty.
|
||||
func createConPtyPipes() (inputRead, inputWrite, outputRead, outputWrite windows.Handle, err error) {
|
||||
if err := windows.CreatePipe(&inputRead, &inputWrite, nil, 0); err != nil {
|
||||
return 0, 0, 0, 0, fmt.Errorf("create input pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := windows.CreatePipe(&outputRead, &outputWrite, nil, 0); err != nil {
|
||||
if closeErr := windows.CloseHandle(inputRead); closeErr != nil {
|
||||
log.Debugf(errCloseInputRead, closeErr)
|
||||
}
|
||||
if closeErr := windows.CloseHandle(inputWrite); closeErr != nil {
|
||||
log.Debugf("close input write handle: %v", closeErr)
|
||||
}
|
||||
return 0, 0, 0, 0, fmt.Errorf("create output pipe: %w", err)
|
||||
}
|
||||
|
||||
return inputRead, inputWrite, outputRead, outputWrite, nil
|
||||
}
|
||||
|
||||
// createConPty creates a Windows ConPty with the specified size and pipe handles.
|
||||
func createConPty(width, height int, inputRead, outputWrite windows.Handle) (windows.Handle, error) {
|
||||
size := windows.Coord{X: int16(width), Y: int16(height)}
|
||||
|
||||
var hPty windows.Handle
|
||||
if err := windows.CreatePseudoConsole(size, inputRead, outputWrite, 0, &hPty); err != nil {
|
||||
return 0, fmt.Errorf("CreatePseudoConsole: %w", err)
|
||||
}
|
||||
|
||||
return hPty, nil
|
||||
}
|
||||
|
||||
// setupConPtyStartupInfo prepares the STARTUPINFOEX with ConPty attributes.
|
||||
func setupConPtyStartupInfo(hPty windows.Handle) (*windows.StartupInfoEx, error) {
|
||||
var siEx windows.StartupInfoEx
|
||||
siEx.StartupInfo.Cb = uint32(unsafe.Sizeof(siEx))
|
||||
|
||||
var attrListSize uintptr
|
||||
ret, _, _ := procInitializeProcThreadAttributeList.Call(0, 1, 0, uintptr(unsafe.Pointer(&attrListSize)))
|
||||
if ret == 0 && attrListSize == 0 {
|
||||
return nil, fmt.Errorf("get attribute list size")
|
||||
}
|
||||
|
||||
attrListBytes := make([]byte, attrListSize)
|
||||
siEx.ProcThreadAttributeList = (*windows.ProcThreadAttributeList)(unsafe.Pointer(&attrListBytes[0]))
|
||||
|
||||
ret, _, err := procInitializeProcThreadAttributeList.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
1,
|
||||
0,
|
||||
uintptr(unsafe.Pointer(&attrListSize)),
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("initialize attribute list: %w", err)
|
||||
}
|
||||
|
||||
ret, _, err = procUpdateProcThreadAttribute.Call(
|
||||
uintptr(unsafe.Pointer(siEx.ProcThreadAttributeList)),
|
||||
0,
|
||||
procThreadAttributePseudoConsole,
|
||||
uintptr(hPty),
|
||||
unsafe.Sizeof(hPty),
|
||||
0,
|
||||
0,
|
||||
)
|
||||
if ret == 0 {
|
||||
return nil, fmt.Errorf("update thread attribute: %w", err)
|
||||
}
|
||||
|
||||
return &siEx, nil
|
||||
}
|
||||
|
||||
// createConPtyProcess creates the actual process with ConPty.
|
||||
func createConPtyProcess(commandLine string, userToken windows.Handle, userEnv []string, workingDir string, siEx *windows.StartupInfoEx) (*windows.ProcessInformation, error) {
|
||||
var pi windows.ProcessInformation
|
||||
creationFlags := uint32(extendedStartupInfoPresent | createUnicodeEnvironment)
|
||||
|
||||
commandLinePtr, err := windows.UTF16PtrFromString(commandLine)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert command line to UTF16: %w", err)
|
||||
}
|
||||
|
||||
envPtr, err := convertEnvironmentToUTF16(userEnv)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var workingDirPtr *uint16
|
||||
if workingDir != "" {
|
||||
workingDirPtr, err = windows.UTF16PtrFromString(workingDir)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("convert working directory to UTF16: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
siEx.StartupInfo.Flags |= windows.STARTF_USESTDHANDLES
|
||||
siEx.StartupInfo.StdInput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdOutput = windows.Handle(0)
|
||||
siEx.StartupInfo.StdErr = siEx.StartupInfo.StdOutput
|
||||
|
||||
if userToken != windows.InvalidHandle {
|
||||
err = windows.CreateProcessAsUser(
|
||||
windows.Token(userToken),
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
} else {
|
||||
err = windows.CreateProcess(
|
||||
nil,
|
||||
commandLinePtr,
|
||||
nil,
|
||||
nil,
|
||||
true,
|
||||
creationFlags,
|
||||
envPtr,
|
||||
workingDirPtr,
|
||||
&siEx.StartupInfo,
|
||||
&pi,
|
||||
)
|
||||
}
|
||||
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create process: %w", err)
|
||||
}
|
||||
|
||||
return &pi, nil
|
||||
}
|
||||
|
||||
// convertEnvironmentToUTF16 converts environment variables to Windows UTF16 format.
|
||||
func convertEnvironmentToUTF16(userEnv []string) (*uint16, error) {
|
||||
if len(userEnv) == 0 {
|
||||
// Return nil pointer for empty environment - Windows API will inherit parent environment
|
||||
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
|
||||
}
|
||||
|
||||
var envUTF16 []uint16
|
||||
for _, envVar := range userEnv {
|
||||
if envVar != "" {
|
||||
utf16Str, err := windows.UTF16FromString(envVar)
|
||||
if err != nil {
|
||||
log.Debugf("skipping invalid environment variable: %s (error: %v)", envVar, err)
|
||||
continue
|
||||
}
|
||||
envUTF16 = append(envUTF16, utf16Str[:len(utf16Str)-1]...)
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
}
|
||||
}
|
||||
envUTF16 = append(envUTF16, 0)
|
||||
|
||||
if len(envUTF16) > 0 {
|
||||
return &envUTF16[0], nil
|
||||
}
|
||||
// Return nil pointer when no valid environment variables found
|
||||
return nil, nil //nolint:nilnil // Intentional nil,nil for empty environment
|
||||
}
|
||||
|
||||
// duplicateToPrimaryToken converts an impersonation token to a primary token.
|
||||
func duplicateToPrimaryToken(token windows.Handle) (windows.Handle, error) {
|
||||
var primaryToken windows.Handle
|
||||
if err := windows.DuplicateTokenEx(
|
||||
windows.Token(token),
|
||||
windows.TOKEN_ALL_ACCESS,
|
||||
nil,
|
||||
windows.SecurityImpersonation,
|
||||
windows.TokenPrimary,
|
||||
(*windows.Token)(&primaryToken),
|
||||
); err != nil {
|
||||
return 0, fmt.Errorf("duplicate token: %w", err)
|
||||
}
|
||||
return primaryToken, nil
|
||||
}
|
||||
|
||||
// SessionExiter provides the Exit method for reporting process exit status.
|
||||
type SessionExiter interface {
|
||||
Exit(code int) error
|
||||
}
|
||||
|
||||
// bridgeConPtyIO handles I/O bridging between ConPty and readers/writers.
|
||||
func bridgeConPtyIO(ctx context.Context, hPty, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer, session SessionExiter, process windows.Handle) error {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
startIOBridging(ctx, &wg, inputWrite, outputRead, reader, writer)
|
||||
|
||||
processErr := waitForProcess(ctx, process)
|
||||
if processErr != nil {
|
||||
return processErr
|
||||
}
|
||||
|
||||
var exitCode uint32
|
||||
if err := windows.GetExitCodeProcess(process, &exitCode); err != nil {
|
||||
log.Debugf("get exit code: %v", err)
|
||||
} else {
|
||||
if err := session.Exit(int(exitCode)); err != nil {
|
||||
log.Debugf("report exit code: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// Clean up in the original order after process completes
|
||||
if err := reader.Close(); err != nil {
|
||||
log.Debugf("close reader: %v", err)
|
||||
}
|
||||
|
||||
ret, _, err := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
if ret == 0 {
|
||||
log.Debugf("close ConPty handle: %v", err)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
if err := windows.CloseHandle(outputRead); err != nil {
|
||||
log.Debugf("close output read handle: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// startIOBridging starts the I/O bridging goroutines.
|
||||
func startIOBridging(ctx context.Context, wg *sync.WaitGroup, inputWrite, outputRead windows.Handle, reader io.ReadCloser, writer io.Writer) {
|
||||
wg.Add(2)
|
||||
|
||||
// Input: reader (SSH session) -> inputWrite (ConPty)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(inputWrite); err != nil {
|
||||
log.Debugf("close input write handle in goroutine: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := io.Copy(&windowsHandleWriter{handle: inputWrite}, reader); err != nil {
|
||||
log.Debugf("input copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
// Output: outputRead (ConPty) -> writer (SSH session)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
if _, err := io.Copy(writer, &windowsHandleReader{handle: outputRead}); err != nil {
|
||||
log.Debugf("output copy ended with error: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
// waitForProcess waits for process completion with context cancellation.
|
||||
func waitForProcess(ctx context.Context, process windows.Handle) error {
|
||||
if _, err := windows.WaitForSingleObject(process, windows.INFINITE); err != nil {
|
||||
return fmt.Errorf("wait for process %d: %w", process, err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// buildShellArgs builds shell arguments for ConPty execution.
|
||||
func buildShellArgs(shell, command string) []string {
|
||||
if command != "" {
|
||||
return []string{shell, PowerShellCommandFlag, command}
|
||||
}
|
||||
return []string{shell}
|
||||
}
|
||||
|
||||
// buildCommandLine builds a Windows command line from arguments using proper escaping.
|
||||
func buildCommandLine(args []string) string {
|
||||
if len(args) == 0 {
|
||||
return ""
|
||||
}
|
||||
|
||||
var result strings.Builder
|
||||
for i, arg := range args {
|
||||
if i > 0 {
|
||||
result.WriteString(" ")
|
||||
}
|
||||
result.WriteString(syscall.EscapeArg(arg))
|
||||
}
|
||||
return result.String()
|
||||
}
|
||||
|
||||
// closeHandles closes multiple Windows handles.
|
||||
func closeHandles(handles ...windows.Handle) {
|
||||
for _, handle := range handles {
|
||||
if handle != windows.InvalidHandle {
|
||||
if err := windows.CloseHandle(handle); err != nil {
|
||||
log.Debugf("close handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// closeProcessInfo closes process and thread handles.
|
||||
func closeProcessInfo(pi *windows.ProcessInformation) {
|
||||
if pi != nil {
|
||||
if err := windows.CloseHandle(pi.Process); err != nil {
|
||||
log.Debugf("close process handle: %v", err)
|
||||
}
|
||||
if err := windows.CloseHandle(pi.Thread); err != nil {
|
||||
log.Debugf("close thread handle: %v", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// windowsHandleReader wraps a Windows handle for reading.
|
||||
type windowsHandleReader struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Read(p []byte) (n int, err error) {
|
||||
var bytesRead uint32
|
||||
if err := windows.ReadFile(r.handle, p, &bytesRead, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesRead), nil
|
||||
}
|
||||
|
||||
func (r *windowsHandleReader) Close() error {
|
||||
return windows.CloseHandle(r.handle)
|
||||
}
|
||||
|
||||
// windowsHandleWriter wraps a Windows handle for writing.
|
||||
type windowsHandleWriter struct {
|
||||
handle windows.Handle
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Write(p []byte) (n int, err error) {
|
||||
var bytesWritten uint32
|
||||
if err := windows.WriteFile(w.handle, p, &bytesWritten, nil); err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return int(bytesWritten), nil
|
||||
}
|
||||
|
||||
func (w *windowsHandleWriter) Close() error {
|
||||
return windows.CloseHandle(w.handle)
|
||||
}
|
||||
290
client/ssh/server/winpty/conpty_test.go
Normal file
290
client/ssh/server/winpty/conpty_test.go
Normal file
@@ -0,0 +1,290 @@
|
||||
//go:build windows
|
||||
|
||||
package winpty
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/windows"
|
||||
)
|
||||
|
||||
func TestBuildShellArgs(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
shell string
|
||||
command string
|
||||
expected []string
|
||||
}{
|
||||
{
|
||||
name: "Shell with command",
|
||||
shell: "powershell.exe",
|
||||
command: "Get-Process",
|
||||
expected: []string{"powershell.exe", "-Command", "Get-Process"},
|
||||
},
|
||||
{
|
||||
name: "CMD with command",
|
||||
shell: "cmd.exe",
|
||||
command: "dir",
|
||||
expected: []string{"cmd.exe", "-Command", "dir"},
|
||||
},
|
||||
{
|
||||
name: "Shell interactive",
|
||||
shell: "powershell.exe",
|
||||
command: "",
|
||||
expected: []string{"powershell.exe"},
|
||||
},
|
||||
{
|
||||
name: "CMD interactive",
|
||||
shell: "cmd.exe",
|
||||
command: "",
|
||||
expected: []string{"cmd.exe"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildShellArgs(tt.shell, tt.command)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCommandLine(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
args []string
|
||||
expected string
|
||||
}{
|
||||
{
|
||||
name: "Simple args",
|
||||
args: []string{"cmd.exe", "/c", "echo"},
|
||||
expected: "cmd.exe /c echo",
|
||||
},
|
||||
{
|
||||
name: "Args with spaces",
|
||||
args: []string{"Program Files\\app.exe", "arg with spaces"},
|
||||
expected: `"Program Files\app.exe" "arg with spaces"`,
|
||||
},
|
||||
{
|
||||
name: "Args with quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "hello world"`},
|
||||
expected: `cmd.exe /c "echo \"hello world\""`,
|
||||
},
|
||||
{
|
||||
name: "PowerShell calling PowerShell",
|
||||
args: []string{"powershell.exe", "-Command", `powershell.exe -Command "Get-Process | Where-Object {$_.Name -eq 'notepad'}"`},
|
||||
expected: `powershell.exe -Command "powershell.exe -Command \"Get-Process | Where-Object {$_.Name -eq 'notepad'}\""`,
|
||||
},
|
||||
{
|
||||
name: "Complex nested quotes",
|
||||
args: []string{"cmd.exe", "/c", `echo "He said \"Hello\" to me"`},
|
||||
expected: `cmd.exe /c "echo \"He said \\\"Hello\\\" to me\""`,
|
||||
},
|
||||
{
|
||||
name: "Path with spaces and args",
|
||||
args: []string{`C:\Program Files\MyApp\app.exe`, "--config", `C:\My Config\settings.json`},
|
||||
expected: `"C:\Program Files\MyApp\app.exe" --config "C:\My Config\settings.json"`,
|
||||
},
|
||||
{
|
||||
name: "Empty argument",
|
||||
args: []string{"cmd.exe", "/c", "echo", ""},
|
||||
expected: `cmd.exe /c echo ""`,
|
||||
},
|
||||
{
|
||||
name: "Argument with backslashes",
|
||||
args: []string{"robocopy", `C:\Source\`, `C:\Dest\`, "/E"},
|
||||
expected: `robocopy C:\Source\ C:\Dest\ /E`,
|
||||
},
|
||||
{
|
||||
name: "Empty args",
|
||||
args: []string{},
|
||||
expected: "",
|
||||
},
|
||||
{
|
||||
name: "Single arg with space",
|
||||
args: []string{"path with spaces"},
|
||||
expected: `"path with spaces"`,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result := buildCommandLine(tt.args)
|
||||
assert.Equal(t, tt.expected, result)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestCreateConPtyPipes(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
|
||||
// Verify all handles are valid
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputRead, "Input read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, inputWrite, "Input write handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputRead, "Output read handle should be valid")
|
||||
assert.NotEqual(t, windows.InvalidHandle, outputWrite, "Output write handle should be valid")
|
||||
|
||||
// Clean up handles
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
|
||||
func TestCreateConPty(t *testing.T) {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
require.NoError(t, err, "Should create ConPty pipes successfully")
|
||||
defer closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
require.NoError(t, err, "Should create ConPty successfully")
|
||||
assert.NotEqual(t, windows.InvalidHandle, hPty, "ConPty handle should be valid")
|
||||
|
||||
// Clean up ConPty
|
||||
ret, _, _ := procClosePseudoConsole.Call(uintptr(hPty))
|
||||
assert.NotEqual(t, uintptr(0), ret, "Should close ConPty successfully")
|
||||
}
|
||||
|
||||
func TestConvertEnvironmentToUTF16(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
userEnv []string
|
||||
hasError bool
|
||||
}{
|
||||
{
|
||||
name: "Valid environment variables",
|
||||
userEnv: []string{"PATH=C:\\Windows", "USER=testuser", "HOME=C:\\Users\\testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Empty environment",
|
||||
userEnv: []string{},
|
||||
hasError: false,
|
||||
},
|
||||
{
|
||||
name: "Environment with empty strings",
|
||||
userEnv: []string{"PATH=C:\\Windows", "", "USER=testuser"},
|
||||
hasError: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
result, err := convertEnvironmentToUTF16(tt.userEnv)
|
||||
if tt.hasError {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
if len(tt.userEnv) == 0 {
|
||||
assert.Nil(t, result, "Empty environment should return nil")
|
||||
} else {
|
||||
assert.NotNil(t, result, "Non-empty environment should return valid pointer")
|
||||
}
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestDuplicateToPrimaryToken(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping token tests in short mode")
|
||||
}
|
||||
|
||||
// Get current process token for testing
|
||||
var token windows.Token
|
||||
err := windows.OpenProcessToken(windows.CurrentProcess(), windows.TOKEN_ALL_ACCESS, &token)
|
||||
require.NoError(t, err, "Should open current process token")
|
||||
defer func() {
|
||||
if err := windows.CloseHandle(windows.Handle(token)); err != nil {
|
||||
t.Logf("Failed to close token: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
primaryToken, err := duplicateToPrimaryToken(windows.Handle(token))
|
||||
require.NoError(t, err, "Should duplicate token to primary")
|
||||
assert.NotEqual(t, windows.InvalidHandle, primaryToken, "Primary token should be valid")
|
||||
|
||||
// Clean up
|
||||
err = windows.CloseHandle(primaryToken)
|
||||
assert.NoError(t, err, "Should close primary token")
|
||||
}
|
||||
|
||||
func TestWindowsHandleReader(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Write test data
|
||||
testData := []byte("Hello, Windows Handle Reader!")
|
||||
var bytesWritten uint32
|
||||
err = windows.WriteFile(writeHandle, testData, &bytesWritten, nil)
|
||||
require.NoError(t, err, "Should write test data")
|
||||
require.Equal(t, uint32(len(testData)), bytesWritten, "Should write all test data")
|
||||
|
||||
// Close write handle to signal EOF
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
writeHandle = windows.InvalidHandle
|
||||
|
||||
// Test reading
|
||||
reader := &windowsHandleReader{handle: readHandle}
|
||||
buffer := make([]byte, len(testData))
|
||||
n, err := reader.Read(buffer)
|
||||
require.NoError(t, err, "Should read from handle")
|
||||
assert.Equal(t, len(testData), n, "Should read expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read expected data")
|
||||
}
|
||||
|
||||
func TestWindowsHandleWriter(t *testing.T) {
|
||||
// Create a pipe for testing
|
||||
var readHandle, writeHandle windows.Handle
|
||||
err := windows.CreatePipe(&readHandle, &writeHandle, nil, 0)
|
||||
require.NoError(t, err, "Should create pipe for testing")
|
||||
defer closeHandles(readHandle, writeHandle)
|
||||
|
||||
// Test writing
|
||||
testData := []byte("Hello, Windows Handle Writer!")
|
||||
writer := &windowsHandleWriter{handle: writeHandle}
|
||||
n, err := writer.Write(testData)
|
||||
require.NoError(t, err, "Should write to handle")
|
||||
assert.Equal(t, len(testData), n, "Should write expected number of bytes")
|
||||
|
||||
// Close write handle
|
||||
if err := windows.CloseHandle(writeHandle); err != nil {
|
||||
t.Fatalf("Should close write handle: %v", err)
|
||||
}
|
||||
|
||||
// Verify data was written by reading it back
|
||||
buffer := make([]byte, len(testData))
|
||||
var bytesRead uint32
|
||||
err = windows.ReadFile(readHandle, buffer, &bytesRead, nil)
|
||||
require.NoError(t, err, "Should read back written data")
|
||||
assert.Equal(t, uint32(len(testData)), bytesRead, "Should read back expected number of bytes")
|
||||
assert.Equal(t, testData, buffer, "Should read back expected data")
|
||||
}
|
||||
|
||||
// BenchmarkConPtyCreation benchmarks ConPty creation performance
|
||||
func BenchmarkConPtyCreation(b *testing.B) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
inputRead, inputWrite, outputRead, outputWrite, err := createConPtyPipes()
|
||||
if err != nil {
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
hPty, err := createConPty(80, 24, inputRead, outputWrite)
|
||||
if err != nil {
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
b.Fatal(err)
|
||||
}
|
||||
|
||||
// Clean up
|
||||
if ret, _, err := procClosePseudoConsole.Call(uintptr(hPty)); ret == 0 {
|
||||
log.Debugf("ClosePseudoConsole failed: %v", err)
|
||||
}
|
||||
closeHandles(inputRead, inputWrite, outputRead, outputWrite)
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user