This commit is contained in:
Viktor Liu
2025-07-02 20:10:59 +02:00
parent 279b77dee0
commit 4bbca28eb6
19 changed files with 71 additions and 417 deletions

View File

@@ -36,7 +36,7 @@ func (s *Server) handleCommand(logger *log.Entry, session ssh.Session, privilege
}
errorMsg += "\n"
if _, writeErr := fmt.Fprintf(session.Stderr(), errorMsg); writeErr != nil {
if _, writeErr := fmt.Fprint(session.Stderr(), errorMsg); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
@@ -150,34 +150,6 @@ func (s *Server) handleCommandIO(logger *log.Entry, stdinPipe io.WriteCloser, se
}
}
// waitForCommandCompletion waits for command completion and handles exit codes
func (s *Server) waitForCommandCompletion(sessionKey SessionKey, session ssh.Session, execCmd *exec.Cmd) bool {
logger := log.WithField("session", sessionKey)
if err := execCmd.Wait(); err != nil {
logger.Debugf("command execution failed: %v", err)
var exitError *exec.ExitError
if errors.As(err, &exitError) {
if err := session.Exit(exitError.ExitCode()); err != nil {
logger.Debugf(errExitSession, err)
}
} else {
if _, writeErr := fmt.Fprintf(session.Stderr(), "failed to execute command: %v\n", err); writeErr != nil {
logger.Debugf(errWriteSession, writeErr)
}
if err := session.Exit(1); err != nil {
logger.Debugf(errExitSession, err)
}
}
return false
}
if err := session.Exit(0); err != nil {
logger.Debugf(errExitSession, err)
}
return true
}
// createPtyCommandWithPrivileges creates the exec.Cmd for Pty execution respecting privilege check results
func (s *Server) createPtyCommandWithPrivileges(cmd []string, privilegeResult PrivilegeCheckResult, ptyReq ssh.Pty, session ssh.Session) (*exec.Cmd, error) {
localUser := privilegeResult.User
@@ -246,23 +218,6 @@ func (s *Server) waitForCommandCleanup(logger *log.Entry, session ssh.Session, e
}
}
// handleCommandSessionCancellation handles command session cancellation
func (s *Server) handleCommandSessionCancellation(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd, done <-chan error) {
logger.Debugf("session cancelled, terminating command")
s.killProcessGroup(execCmd)
select {
case err := <-done:
logger.Debugf("command terminated after session cancellation: %v", err)
case <-time.After(5 * time.Second):
logger.Warnf("command did not terminate within 5 seconds after session cancellation")
}
if err := session.Exit(130); err != nil {
logger.Debugf(errExitSession, err)
}
}
// handleCommandCompletion handles command completion
func (s *Server) handleCommandCompletion(logger *log.Entry, session ssh.Session, err error) bool {
if err != nil {

View File

@@ -15,6 +15,7 @@ import (
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
@@ -88,28 +89,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
"echo", "hello_world")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("SSH command failed: %v", err)
t.Logf("Output: %s", string(output))
return
}
assert.Contains(t, string(output), "hello_world", "SSH command should execute successfully")
}
// testSSHCommandExecution tests basic command execution with system SSH client.
func testSSHCommandExecution(t *testing.T, host, port, keyFile string) {
cmd := exec.Command("ssh",
"-i", keyFile,
"-p", port,
"-o", "StrictHostKeyChecking=no",
"-o", "UserKnownHostsFile=/dev/null",
"-o", "ConnectTimeout=5",
fmt.Sprintf("test-user@%s", host),
"echo", "hello_world")
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("SSH command failed: %v", err)
t.Logf("Output: %s", string(output))
@@ -269,7 +249,9 @@ func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
_, err = conn.Write([]byte(request))
require.NoError(t, err)
conn.SetReadDeadline(time.Now().Add(3 * time.Second))
if err := conn.SetReadDeadline(time.Now().Add(3 * time.Second)); err != nil {
log.Debugf("failed to set read deadline: %v", err)
}
response := make([]byte, len(expectedResponse))
n, err := io.ReadFull(conn, response)
if err != nil {
@@ -305,16 +287,16 @@ func generateOpenSSHKey() ([]byte, []byte, error) {
// Remove the temp file so ssh-keygen can create it
if err := os.Remove(keyPath); err != nil {
// Ignore if file doesn't exist, we just need it gone
t.Logf("failed to remove key file: %v", err)
}
// Clean up temp files
defer func() {
if err := os.Remove(keyPath); err != nil {
// Ignore cleanup errors but could log them in debug mode
t.Logf("failed to cleanup key file: %v", err)
}
if err := os.Remove(keyPath + ".pub"); err != nil {
// Ignore cleanup errors but could log them in debug mode
t.Logf("failed to cleanup public key file: %v", err)
}
}()

View File

@@ -1,226 +0,0 @@
//go:build unix
package server
import (
"context"
"os"
"os/exec"
"os/user"
"runtime"
"strconv"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestPrivilegeDropper_ValidatePrivileges(t *testing.T) {
pd := NewPrivilegeDropper()
tests := []struct {
name string
uid uint32
gid uint32
wantErr bool
}{
{
name: "valid non-root user",
uid: 1000,
gid: 1000,
wantErr: false,
},
{
name: "root UID should be rejected",
uid: 0,
gid: 1000,
wantErr: true,
},
{
name: "root GID should be rejected",
uid: 1000,
gid: 0,
wantErr: true,
},
{
name: "both root should be rejected",
uid: 0,
gid: 0,
wantErr: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := pd.validatePrivileges(tt.uid, tt.gid)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}
func TestPrivilegeDropper_CreateExecutorCommand(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000, 1001},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "ls -la",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify the command is calling netbird ssh exec
assert.Contains(t, cmd.Args, "ssh")
assert.Contains(t, cmd.Args, "exec")
assert.Contains(t, cmd.Args, "--uid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--gid")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "--groups")
assert.Contains(t, cmd.Args, "1000")
assert.Contains(t, cmd.Args, "1001")
assert.Contains(t, cmd.Args, "--working-dir")
assert.Contains(t, cmd.Args, "/home/testuser")
assert.Contains(t, cmd.Args, "--shell")
assert.Contains(t, cmd.Args, "/bin/bash")
assert.Contains(t, cmd.Args, "--cmd")
assert.Contains(t, cmd.Args, "ls -la")
}
func TestPrivilegeDropper_CreateExecutorCommandInteractive(t *testing.T) {
pd := NewPrivilegeDropper()
config := ExecutorConfig{
UID: 1000,
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/home/testuser",
Shell: "/bin/bash",
Command: "",
}
cmd, err := pd.CreateExecutorCommand(context.Background(), config)
require.NoError(t, err)
require.NotNil(t, cmd)
// Verify no command mode (command is empty so no --cmd flag)
assert.NotContains(t, cmd.Args, "--cmd")
assert.NotContains(t, cmd.Args, "--interactive")
}
// TestPrivilegeDropper_ActualPrivilegeDrop tests actual privilege dropping
// This test requires root privileges and will be skipped if not running as root
func TestPrivilegeDropper_ActualPrivilegeDrop(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("Privilege dropping not supported on Windows")
}
if os.Geteuid() != 0 {
t.Skip("This test requires root privileges")
}
// Find a non-root user to test with
testUser, err := user.Lookup("nobody")
if err != nil {
// Try to find any non-root user
testUser, err = findNonRootUser()
if err != nil {
t.Skip("No suitable non-root user found for testing")
}
}
uid64, err := strconv.ParseUint(testUser.Uid, 10, 32)
require.NoError(t, err)
targetUID := uint32(uid64)
gid64, err := strconv.ParseUint(testUser.Gid, 10, 32)
require.NoError(t, err)
targetGID := uint32(gid64)
// Test in a child process to avoid affecting the test runner
if os.Getenv("TEST_PRIVILEGE_DROP") == "1" {
pd := NewPrivilegeDropper()
// This should succeed
err := pd.DropPrivileges(targetUID, targetGID, []uint32{targetGID})
require.NoError(t, err)
// Verify we are now running as the target user
currentUID := uint32(os.Geteuid())
currentGID := uint32(os.Getegid())
assert.Equal(t, targetUID, currentUID, "UID should match target")
assert.Equal(t, targetGID, currentGID, "GID should match target")
assert.NotEqual(t, uint32(0), currentUID, "Should not be running as root")
assert.NotEqual(t, uint32(0), currentGID, "Should not be running as root group")
return
}
// Fork a child process to test privilege dropping
cmd := os.Args[0]
args := []string{"-test.run=TestPrivilegeDropper_ActualPrivilegeDrop"}
env := append(os.Environ(), "TEST_PRIVILEGE_DROP=1")
execCmd := exec.Command(cmd, args...)
execCmd.Env = env
err = execCmd.Run()
require.NoError(t, err, "Child process should succeed")
}
// findNonRootUser finds any non-root user on the system for testing
func findNonRootUser() (*user.User, error) {
// Try common non-root users
commonUsers := []string{"nobody", "daemon", "bin", "sys", "sync", "games", "man", "lp", "mail", "news", "uucp", "proxy", "www-data", "backup", "list", "irc"}
for _, username := range commonUsers {
if u, err := user.Lookup(username); err == nil {
uid64, err := strconv.ParseUint(u.Uid, 10, 32)
if err != nil {
continue
}
if uid64 != 0 { // Not root
return u, nil
}
}
}
// If no common users found, create a minimal user info for testing
// This won't actually work for privilege dropping but allows the test structure
return &user.User{
Uid: "65534", // Standard nobody UID
Gid: "65534", // Standard nobody GID
Username: "nobody",
Name: "nobody",
HomeDir: "/nonexistent",
}, nil
}
func TestPrivilegeDropper_ExecuteWithPrivilegeDrop_Validation(t *testing.T) {
pd := NewPrivilegeDropper()
// Test validation of root privileges - this should be caught in CreateExecutorCommand
config := ExecutorConfig{
UID: 0, // Root UID should be rejected
GID: 1000,
Groups: []uint32{1000},
WorkingDir: "/tmp",
Shell: "/bin/sh",
Command: "echo test",
}
_, err := pd.CreateExecutorCommand(context.Background(), config)
assert.Error(t, err)
assert.Contains(t, err.Error(), "root user")
}

View File

@@ -376,23 +376,6 @@ func (s *Server) proxyForwardConnection(ctx ssh.Context, logger *log.Entry, conn
}
}
// registerConnectionCancel stores a cancel function for a connection
func (s *Server) registerConnectionCancel(key ConnectionKey, cancel context.CancelFunc) {
s.mu.Lock()
defer s.mu.Unlock()
if s.sessionCancels == nil {
s.sessionCancels = make(map[ConnectionKey]context.CancelFunc)
}
s.sessionCancels[key] = cancel
}
// unregisterConnectionCancel removes a connection's cancel function
func (s *Server) unregisterConnectionCancel(key ConnectionKey) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.sessionCancels, key)
}
// monitorSessionContext watches for session cancellation and closes connections
func (s *Server) monitorSessionContext(ctx context.Context, channel cryptossh.Channel, conn net.Conn, closed chan struct{}, closeOnce *bool, logger *log.Entry) {
<-ctx.Done()

View File

@@ -375,16 +375,6 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
return "unknown"
}
// cleanupConnectionPortForward removes port forward state from a connection
func (s *Server) cleanupConnectionPortForward(sshConn *cryptossh.ServerConn) {
s.mu.Lock()
defer s.mu.Unlock()
if state, exists := s.sshConnections[sshConn]; exists {
state.hasActivePortForward = false
}
}
// connectionValidator validates incoming connections based on source IP
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()

View File

@@ -129,17 +129,17 @@ func (s *Server) buildUserLookupErrorMessage(err error) string {
switch {
case errors.As(err, &privilegedErr):
if privilegedErr.Username == "root" {
return fmt.Sprintf("root login is disabled on this SSH server\n")
return "root login is disabled on this SSH server\n"
}
return fmt.Sprintf("privileged user access is disabled on this SSH server\n")
return "privileged user access is disabled on this SSH server\n"
case errors.Is(err, ErrPrivilegeRequired):
return fmt.Sprintf("Windows user switching failed - NetBird must run with elevated privileges for user switching\n")
return "Windows user switching failed - NetBird must run with elevated privileges for user switching\n"
case errors.Is(err, ErrPrivilegedUserSwitch):
return fmt.Sprintf("Cannot switch to privileged user - current user lacks required privileges\n")
return "Cannot switch to privileged user - current user lacks required privileges\n"
default:
return fmt.Sprintf("User authentication failed\n")
return "User authentication failed\n"
}
}

View File

@@ -18,7 +18,7 @@ import (
const (
defaultUnixShell = "/bin/sh"
pwshExe = "pwsh.exe"
pwshExe = "pwsh.exe" // #nosec G101 - This is not a credential, just executable name
powershellExe = "powershell.exe"
)
@@ -104,7 +104,7 @@ func prepareUserEnv(user *user.User, shell string) []string {
fmt.Sprint("USER=" + user.Username),
fmt.Sprint("LOGNAME=" + user.Username),
fmt.Sprint("HOME=" + user.HomeDir),
fmt.Sprint("PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games"),
"PATH=/usr/local/bin:/usr/bin:/bin:/usr/local/games:/usr/games",
}
}

View File

@@ -85,7 +85,7 @@ func attachSocketFilter(listener net.Listener, wgIfIndex int) error {
fd := int(file.Fd())
_, _, errno := syscall.Syscall6(
syscall.SYS_SETSOCKOPT,
unix.SYS_SETSOCKOPT,
uintptr(fd),
uintptr(unix.SOL_SOCKET),
uintptr(unix.SO_ATTACH_FILTER),

View File

@@ -183,17 +183,6 @@ func isSameResolvedUser(user1, user2 *user.User) bool {
return user1.Uid == user2.Uid
}
// logPrivilegeCheckResult logs the final result of privilege checking
func (s *Server) logPrivilegeCheckResult(req PrivilegeCheckRequest, result PrivilegeCheckResult) {
if !result.Allowed {
log.Debugf("Privilege check denied for %s (user: %s, feature: %s): %v",
req.FeatureName, req.RequestedUsername, req.FeatureName, result.Error)
} else {
log.Debugf("Privilege check allowed for %s (user: %s, requires_switching: %v)",
req.FeatureName, req.RequestedUsername, result.RequiresUserSwitching)
}
}
// privilegeCheckContext holds all context needed for privilege checking
type privilegeCheckContext struct {
currentUser *user.User
@@ -389,7 +378,7 @@ func isWindowsPrivilegedSID(sid string) bool {
return false
}
// buildShellArgs builds shell arguments for executing commands.
// buildShellArgs builds shell arguments for executing commands
func buildShellArgs(shell, command string) []string {
if command != "" {
return []string{shell, "-Command", command}

View File

@@ -674,19 +674,20 @@ func TestCheckPrivileges_ActualPlatform(t *testing.T) {
FeatureName: "SSH login",
})
if actualOS == "windows" {
switch {
case actualOS == "windows":
// Windows should deny user switching
assert.False(t, result.Allowed, "Windows should deny user switching")
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")
assert.Contains(t, result.Error.Error(), "user switching not supported",
"Should indicate user switching not supported")
} else if !actualIsPrivileged {
case !actualIsPrivileged:
// Non-privileged Unix processes should fallback to current user
assert.True(t, result.Allowed, "Non-privileged Unix process should fallback to current user")
assert.False(t, result.RequiresUserSwitching, "Fallback means no switching actually happens")
assert.True(t, result.UsedFallback, "Should indicate fallback was used")
assert.NotNil(t, result.User, "Should return current user")
} else {
default:
// Privileged Unix processes should attempt user lookup
assert.False(t, result.Allowed, "Should fail due to nonexistent user")
assert.True(t, result.RequiresUserSwitching, "Should indicate switching is needed")