diff --git a/client/ssh/client/client_test.go b/client/ssh/client/client_test.go index 77be622ec..d141ac00a 100644 --- a/client/ssh/client/client_test.go +++ b/client/ssh/client/client_test.go @@ -7,12 +7,14 @@ import ( "io" "net" "os" + "os/exec" "os/user" "runtime" "strings" "testing" "time" + log "github.com/sirupsen/logrus" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" cryptossh "golang.org/x/crypto/ssh" @@ -21,6 +23,17 @@ import ( sshserver "github.com/netbirdio/netbird/client/ssh/server" ) +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Run tests + code := m.Run() + + // Cleanup any created test users + cleanupTestUsers() + + os.Exit(code) +} + func TestSSHClient_DialWithKey(t *testing.T) { // Generate host key for server hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) @@ -417,7 +430,11 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) { go func() { err := client.LocalPortForward(ctx, localAddr, testServerAddr) if err != nil && !errors.Is(err, context.Canceled) { - t.Logf("Port forward error: %v", err) + if isWindowsPrivilegeError(err) { + t.Logf("Port forward failed due to Windows privilege restrictions: %v", err) + } else { + t.Logf("Port forward error: %v", err) + } } }() @@ -448,6 +465,23 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) { func getCurrentUsername() string { if runtime.GOOS == "windows" { if currentUser, err := user.Current(); err == nil { + // Check if this is a system account that can't authenticate + if isSystemAccount(currentUser.Username) { + // In CI environments, create a test user; otherwise try Administrator + if isCI() { + if testUser := getOrCreateTestUser(); testUser != "" { + return testUser + } + } else { + // Try Administrator first for local development + if _, err := user.Lookup("Administrator"); err == nil { + return "Administrator" + } + if testUser := getOrCreateTestUser(); testUser != "" { + return testUser + } + } + } // On Windows, return the full domain\username for proper authentication return currentUser.Username } @@ -463,3 +497,154 @@ func getCurrentUsername() string { return "test-user" } + +// isCI checks if we're running in a CI environment +func isCI() bool { + ciEnvVars := []string{ + "CI", "CONTINUOUS_INTEGRATION", "GITHUB_ACTIONS", + "GITLAB_CI", "JENKINS_URL", "BUILDKITE", "CIRCLECI", + } + + for _, envVar := range ciEnvVars { + if os.Getenv(envVar) != "" { + return true + } + } + return false +} + +// getOrCreateTestUser creates a test user on Windows if needed +func getOrCreateTestUser() string { + testUsername := "netbird-test-user" + + // Check if user already exists + if _, err := user.Lookup(testUsername); err == nil { + return testUsername + } + + // Try to create the user using PowerShell + if createWindowsTestUser(testUsername) { + // Register cleanup for the test user + registerTestUserCleanup(testUsername) + return testUsername + } + + return "" +} + +var createdTestUsers = make(map[string]bool) +var testUsersToCleanup []string + +// registerTestUserCleanup registers a test user for cleanup +func registerTestUserCleanup(username string) { + if !createdTestUsers[username] { + createdTestUsers[username] = true + testUsersToCleanup = append(testUsersToCleanup, username) + } +} + +// cleanupTestUsers removes all created test users +func cleanupTestUsers() { + for _, username := range testUsersToCleanup { + removeWindowsTestUser(username) + } + testUsersToCleanup = nil + createdTestUsers = make(map[string]bool) +} + +// removeWindowsTestUser removes a local user on Windows using PowerShell +func removeWindowsTestUser(username string) { + if runtime.GOOS != "windows" { + return + } + + // PowerShell command to remove a local user + psCmd := fmt.Sprintf(` + try { + Remove-LocalUser -Name "%s" -ErrorAction Stop + Write-Output "User removed successfully" + } catch { + if ($_.Exception.Message -like "*cannot be found*") { + Write-Output "User not found (already removed)" + } else { + Write-Error $_.Exception.Message + } + } + `, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output)) + } else { + log.Printf("Test user %s cleanup result: %s", username, string(output)) + } +} + +// createWindowsTestUser creates a local user on Windows using PowerShell +func createWindowsTestUser(username string) bool { + if runtime.GOOS != "windows" { + return false + } + + // PowerShell command to create a local user + psCmd := fmt.Sprintf(` + try { + $password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force + New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires + Add-LocalGroupMember -Group "Users" -Member "%s" + Write-Output "User created successfully" + } catch { + if ($_.Exception.Message -like "*already exists*") { + Write-Output "User already exists" + } else { + Write-Error $_.Exception.Message + exit 1 + } + } + `, username, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("Failed to create test user: %v, output: %s", err, string(output)) + return false + } + + log.Printf("Test user creation result: %s", string(output)) + return true +} + +// isSystemAccount checks if the user is a system account that can't authenticate +func isSystemAccount(username string) bool { + systemAccounts := []string{ + "system", + "NT AUTHORITY\\SYSTEM", + "NT AUTHORITY\\LOCAL SERVICE", + "NT AUTHORITY\\NETWORK SERVICE", + } + + for _, sysAccount := range systemAccounts { + if strings.EqualFold(username, sysAccount) { + return true + } + } + return false +} + +// isWindowsPrivilegeError checks if an error is related to Windows privilege restrictions +func isWindowsPrivilegeError(err error) bool { + if err == nil { + return false + } + + errStr := strings.ToLower(err.Error()) + return strings.Contains(errStr, "ntstatus=0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD + strings.Contains(errStr, "0xc0000041") || // STATUS_PRIVILEGE_NOT_HELD (LsaRegisterLogonProcess) + strings.Contains(errStr, "0xc0000062") || // STATUS_PRIVILEGE_NOT_HELD (LsaLogonUser) + strings.Contains(errStr, "privilege") || + strings.Contains(errStr, "access denied") || + strings.Contains(errStr, "user authentication failed") +} diff --git a/client/ssh/server/command_execution.go b/client/ssh/server/command_execution.go index 0035a5f74..57589f718 100644 --- a/client/ssh/server/command_execution.go +++ b/client/ssh/server/command_execution.go @@ -66,7 +66,6 @@ func (s *Server) createCommand(privilegeResult PrivilegeCheckResult, session ssh return cmd, nil } - // executeCommand executes the command and handles I/O and exit codes func (s *Server) executeCommand(logger *log.Entry, session ssh.Session, execCmd *exec.Cmd) bool { s.setupProcessGroup(execCmd) diff --git a/client/ssh/server/compatibility_test.go b/client/ssh/server/compatibility_test.go index 920da638d..7cc924215 100644 --- a/client/ssh/server/compatibility_test.go +++ b/client/ssh/server/compatibility_test.go @@ -23,6 +23,17 @@ import ( nbssh "github.com/netbirdio/netbird/client/ssh" ) +// TestMain handles package-level setup and cleanup +func TestMain(m *testing.M) { + // Run tests + code := m.Run() + + // Cleanup any created test users + cleanupTestUsers() + + os.Exit(code) +} + // TestSSHServerCompatibility tests that our SSH server is compatible with the system SSH client func TestSSHServerCompatibility(t *testing.T) { if testing.Short() { @@ -61,12 +72,11 @@ func TestSSHServerCompatibility(t *testing.T) { host, portStr, err := net.SplitHostPort(serverAddr) require.NoError(t, err) - // Get current user for SSH connection instead of hardcoded test-user - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user for compatibility test") + // Get appropriate user for SSH connection (handle system accounts) + username := getTestUsername(t) t.Run("basic command execution", func(t *testing.T) { - testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, currentUser.Username) + testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username) }) t.Run("interactive command", func(t *testing.T) { @@ -102,9 +112,8 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username // testSSHInteractiveCommand tests interactive shell session. func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) { - // Get current user for SSH connection - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user") + // Get appropriate user for SSH connection + username := getTestUsername(t) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) defer cancel() @@ -115,7 +124,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) { "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", - fmt.Sprintf("%s@%s", currentUser.Username, host)) + fmt.Sprintf("%s@%s", username, host)) stdin, err := cmd.StdinPipe() if err != nil { @@ -168,9 +177,8 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) { // testSSHPortForwarding tests port forwarding compatibility. func testSSHPortForwarding(t *testing.T, host, port, keyFile string) { - // Get current user for SSH connection - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user") + // Get appropriate user for SSH connection + username := getTestUsername(t) testServer, err := net.Listen("tcp", "127.0.0.1:0") require.NoError(t, err) @@ -222,7 +230,7 @@ func testSSHPortForwarding(t *testing.T, host, port, keyFile string) { "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", "-N", - fmt.Sprintf("%s@%s", currentUser.Username, host)) + fmt.Sprintf("%s@%s", username, host)) err = cmd.Start() if err != nil { @@ -455,9 +463,8 @@ func TestSSHServerFeatureCompatibility(t *testing.T) { // testCommandWithFlags tests that commands with flags work properly func testCommandWithFlags(t *testing.T, host, port, keyFile string) { - // Get current user for SSH connection - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user") + // Get appropriate user for SSH connection + username := getTestUsername(t) // Test ls with flags cmd := exec.Command("ssh", @@ -466,7 +473,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) { "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", - fmt.Sprintf("%s@%s", currentUser.Username, host), + fmt.Sprintf("%s@%s", username, host), "ls", "-la", "/tmp") output, err := cmd.CombinedOutput() @@ -483,9 +490,8 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) { // testEnvironmentVariables tests that environment is properly set up func testEnvironmentVariables(t *testing.T, host, port, keyFile string) { - // Get current user for SSH connection - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user") + // Get appropriate user for SSH connection + username := getTestUsername(t) cmd := exec.Command("ssh", "-i", keyFile, @@ -493,7 +499,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) { "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", - fmt.Sprintf("%s@%s", currentUser.Username, host), + fmt.Sprintf("%s@%s", username, host), "echo", "$HOME") output, err := cmd.CombinedOutput() @@ -511,9 +517,8 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) { // testExitCodes tests that exit codes are properly handled func testExitCodes(t *testing.T, host, port, keyFile string) { - // Get current user for SSH connection - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user") + // Get appropriate user for SSH connection + username := getTestUsername(t) // Test successful command (exit code 0) cmd := exec.Command("ssh", @@ -522,10 +527,10 @@ func testExitCodes(t *testing.T, host, port, keyFile string) { "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", - fmt.Sprintf("%s@%s", currentUser.Username, host), + fmt.Sprintf("%s@%s", username, host), "true") // always succeeds - err = cmd.Run() + err := cmd.Run() assert.NoError(t, err, "Command with exit code 0 should succeed") // Test failing command (exit code 1) @@ -535,7 +540,7 @@ func testExitCodes(t *testing.T, host, port, keyFile string) { "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", - fmt.Sprintf("%s@%s", currentUser.Username, host), + fmt.Sprintf("%s@%s", username, host), "false") // always fails err = cmd.Run() @@ -557,9 +562,8 @@ func TestSSHServerSecurityFeatures(t *testing.T) { t.Skip("SSH client not available on this system") } - // Get current user for SSH connection - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user") + // Get appropriate user for SSH connection + username := getTestUsername(t) // Set up SSH server with specific security settings hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) @@ -596,7 +600,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) { "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", "-o", "PasswordAuthentication=no", - fmt.Sprintf("%s@%s", currentUser.Username, host), + fmt.Sprintf("%s@%s", username, host), "echo", "auth_success") output, err := cmd.CombinedOutput() @@ -625,7 +629,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) { "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", "-o", "PasswordAuthentication=no", - fmt.Sprintf("%s@%s", currentUser.Username, host), + fmt.Sprintf("%s@%s", username, host), "echo", "should_not_work") err = cmd.Run() @@ -643,9 +647,8 @@ func TestCrossPlatformCompatibility(t *testing.T) { t.Skip("SSH client not available on this system") } - // Get current user for SSH connection - currentUser, err := user.Current() - require.NoError(t, err, "Should be able to get current user") + // Get appropriate user for SSH connection + username := getTestUsername(t) // Set up SSH server hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) @@ -689,7 +692,7 @@ func TestCrossPlatformCompatibility(t *testing.T) { "-o", "StrictHostKeyChecking=no", "-o", "UserKnownHostsFile=/dev/null", "-o", "ConnectTimeout=5", - fmt.Sprintf("%s@%s", currentUser.Username, host), + fmt.Sprintf("%s@%s", username, host), testCommand) output, err := cmd.CombinedOutput() @@ -703,3 +706,171 @@ func TestCrossPlatformCompatibility(t *testing.T) { t.Logf("Platform command output: %s", outputStr) assert.NotEmpty(t, outputStr, "Platform-specific command should produce output") } + +// getTestUsername returns an appropriate username for testing +func getTestUsername(t *testing.T) string { + if runtime.GOOS == "windows" { + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + + // Check if this is a system account that can't authenticate + if isSystemAccount(currentUser.Username) { + // In CI environments, create a test user; otherwise try Administrator + if isCI() { + if testUser := getOrCreateTestUser(t); testUser != "" { + return testUser + } + } else { + // Try Administrator first for local development + if _, err := user.Lookup("Administrator"); err == nil { + return "Administrator" + } + if testUser := getOrCreateTestUser(t); testUser != "" { + return testUser + } + } + } + return currentUser.Username + } + + currentUser, err := user.Current() + require.NoError(t, err, "Should be able to get current user") + return currentUser.Username +} + +// isCI checks if we're running in a CI environment +func isCI() bool { + ciEnvVars := []string{ + "CI", "CONTINUOUS_INTEGRATION", "GITHUB_ACTIONS", + "GITLAB_CI", "JENKINS_URL", "BUILDKITE", "CIRCLECI", + } + + for _, envVar := range ciEnvVars { + if os.Getenv(envVar) != "" { + return true + } + } + return false +} + +// isSystemAccount checks if the user is a system account that can't authenticate +func isSystemAccount(username string) bool { + systemAccounts := []string{ + "system", + "NT AUTHORITY\\SYSTEM", + "NT AUTHORITY\\LOCAL SERVICE", + "NT AUTHORITY\\NETWORK SERVICE", + } + + for _, sysAccount := range systemAccounts { + if strings.EqualFold(username, sysAccount) { + return true + } + } + return false +} + +var compatTestCreatedUsers = make(map[string]bool) +var compatTestUsersToCleanup []string + +// registerTestUserCleanup registers a test user for cleanup +func registerTestUserCleanup(username string) { + if !compatTestCreatedUsers[username] { + compatTestCreatedUsers[username] = true + compatTestUsersToCleanup = append(compatTestUsersToCleanup, username) + } +} + +// cleanupTestUsers removes all created test users +func cleanupTestUsers() { + for _, username := range compatTestUsersToCleanup { + removeWindowsTestUser(username) + } + compatTestUsersToCleanup = nil + compatTestCreatedUsers = make(map[string]bool) +} + +// getOrCreateTestUser creates a test user on Windows if needed +func getOrCreateTestUser(t *testing.T) string { + testUsername := "netbird-test-user" + + // Check if user already exists + if _, err := user.Lookup(testUsername); err == nil { + return testUsername + } + + // Try to create the user using PowerShell + if createWindowsTestUser(t, testUsername) { + // Register cleanup for the test user + registerTestUserCleanup(testUsername) + return testUsername + } + + return "" +} + +// removeWindowsTestUser removes a local user on Windows using PowerShell +func removeWindowsTestUser(username string) { + if runtime.GOOS != "windows" { + return + } + + // PowerShell command to remove a local user + psCmd := fmt.Sprintf(` + try { + Remove-LocalUser -Name "%s" -ErrorAction Stop + Write-Output "User removed successfully" + } catch { + if ($_.Exception.Message -like "*cannot be found*") { + Write-Output "User not found (already removed)" + } else { + Write-Error $_.Exception.Message + } + } + `, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output)) + } else { + log.Printf("Test user %s cleanup result: %s", username, string(output)) + } +} + +// createWindowsTestUser creates a local user on Windows using PowerShell +func createWindowsTestUser(t *testing.T, username string) bool { + if runtime.GOOS != "windows" { + return false + } + + // PowerShell command to create a local user + psCmd := fmt.Sprintf(` + try { + $password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force + New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires + Add-LocalGroupMember -Group "Users" -Member "%s" + Write-Output "User created successfully" + } catch { + if ($_.Exception.Message -like "*already exists*") { + Write-Output "User already exists" + } else { + Write-Error $_.Exception.Message + exit 1 + } + } + `, username, username) + + cmd := exec.Command("powershell", "-Command", psCmd) + output, err := cmd.CombinedOutput() + + if err != nil { + t.Logf("Failed to create test user: %v, output: %s", err, string(output)) + return false + } + + t.Logf("Test user creation result: %s", string(output)) + return true +} + diff --git a/client/ssh/server/server_config_test.go b/client/ssh/server/server_config_test.go index 1b10b1766..b61c9c84b 100644 --- a/client/ssh/server/server_config_test.go +++ b/client/ssh/server/server_config_test.go @@ -94,15 +94,24 @@ func TestServer_RootLoginRestriction(t *testing.T) { for _, tt := range tests { t.Run(tt.name, func(t *testing.T) { // Mock privileged environment to test root access controls + // Set up mock users based on platform + mockUsers := map[string]*user.User{ + "root": createTestUser("root", "0", "0", "/root"), + "testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"), + } + + // Add Windows-specific users for Administrator tests + if runtime.GOOS == "windows" { + mockUsers["Administrator"] = createTestUser("Administrator", "500", "544", "C:\\Users\\Administrator") + mockUsers["administrator"] = createTestUser("administrator", "500", "544", "C:\\Users\\administrator") + } + cleanup := setupTestDependencies( createTestUser("root", "0", "0", "/root"), // Running as root nil, runtime.GOOS, 0, // euid 0 (root) - map[string]*user.User{ - "root": createTestUser("root", "0", "0", "/root"), - "testuser": createTestUser("testuser", "1000", "1000", "/home/testuser"), - }, + mockUsers, nil, ) defer cleanup()