mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 17:26:40 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -9,7 +9,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -21,15 +20,24 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
cleanupTestUsers()
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate OpenSSH-compatible keys for client
|
||||
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t)
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get appropriate user for SSH connection (handle system accounts)
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("basic command execution", func(t *testing.T) {
|
||||
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
||||
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
|
||||
// testSSHInteractiveCommand tests interactive shell session.
|
||||
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// testSSHPortForwarding tests port forwarding compatibility.
|
||||
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && isCI() {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
// testCommandWithFlags tests that commands with flags work properly
|
||||
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test ls with flags
|
||||
cmd := exec.Command("ssh",
|
||||
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// testEnvironmentVariables tests that environment is properly set up
|
||||
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// testExitCodes tests that exit codes are properly handled
|
||||
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test successful command (exit code 0)
|
||||
cmd := exec.Command("ssh",
|
||||
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server with specific security settings
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
t.Logf("Platform command output: %s", outputStr)
|
||||
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
||||
}
|
||||
|
||||
// getTestUsername returns an appropriate username for testing
|
||||
func getTestUsername(t *testing.T) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Check if this is a system account that can't authenticate
|
||||
if isSystemAccount(currentUser.Username) {
|
||||
// In CI environments, create a test user; otherwise try Administrator
|
||||
if isCI() {
|
||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
} else {
|
||||
// Try Administrator first for local development
|
||||
if _, err := user.Lookup("Administrator"); err == nil {
|
||||
return "Administrator"
|
||||
}
|
||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
// isCI checks if we're running in a CI environment
|
||||
func isCI() bool {
|
||||
// Check standard CI environment variables
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isSystemAccount checks if the user is a system account that can't authenticate
|
||||
func isSystemAccount(username string) bool {
|
||||
systemAccounts := []string{
|
||||
"system",
|
||||
"NT AUTHORITY\\SYSTEM",
|
||||
"NT AUTHORITY\\LOCAL SERVICE",
|
||||
"NT AUTHORITY\\NETWORK SERVICE",
|
||||
}
|
||||
|
||||
for _, sysAccount := range systemAccounts {
|
||||
if strings.EqualFold(username, sysAccount) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var compatTestCreatedUsers = make(map[string]bool)
|
||||
var compatTestUsersToCleanup []string
|
||||
|
||||
// registerTestUserCleanup registers a test user for cleanup
|
||||
func registerTestUserCleanup(username string) {
|
||||
if !compatTestCreatedUsers[username] {
|
||||
compatTestCreatedUsers[username] = true
|
||||
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupTestUsers removes all created test users
|
||||
func cleanupTestUsers() {
|
||||
for _, username := range compatTestUsersToCleanup {
|
||||
removeWindowsTestUser(username)
|
||||
}
|
||||
compatTestUsersToCleanup = nil
|
||||
compatTestCreatedUsers = make(map[string]bool)
|
||||
}
|
||||
|
||||
// getOrCreateTestUser creates a test user on Windows if needed
|
||||
func getOrCreateTestUser(t *testing.T) string {
|
||||
testUsername := "netbird-test-user"
|
||||
|
||||
// Check if user already exists
|
||||
if _, err := user.Lookup(testUsername); err == nil {
|
||||
return testUsername
|
||||
}
|
||||
|
||||
// Try to create the user using PowerShell
|
||||
if createWindowsTestUser(t, testUsername) {
|
||||
// Register cleanup for the test user
|
||||
registerTestUserCleanup(testUsername)
|
||||
return testUsername
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// removeWindowsTestUser removes a local user on Windows using PowerShell
|
||||
func removeWindowsTestUser(username string) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
// PowerShell command to remove a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
||||
Write-Output "User removed successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*cannot be found*") {
|
||||
Write-Output "User not found (already removed)"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
}
|
||||
}
|
||||
`, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
||||
} else {
|
||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// createWindowsTestUser creates a local user on Windows using PowerShell
|
||||
func createWindowsTestUser(t *testing.T, username string) bool {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false
|
||||
}
|
||||
|
||||
// PowerShell command to create a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
||||
Write-Output "User created successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*already exists*") {
|
||||
Write-Output "User already exists"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
`, username, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
|
||||
return false
|
||||
}
|
||||
|
||||
t.Logf("Test user creation result: %s", string(output))
|
||||
return true
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user