Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -9,7 +9,6 @@ import (
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
@@ -21,15 +20,24 @@ import (
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
cleanupTestUsers()
testutil.CleanupTestUsers()
os.Exit(code)
}
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err)
// Generate OpenSSH-compatible keys for client
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t)
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err)
// Get appropriate user for SSH connection (handle system accounts)
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
t.Run("basic command execution", func(t *testing.T) {
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
// testSSHInteractiveCommand tests interactive shell session.
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// testSSHPortForwarding tests port forwarding compatibility.
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
t.Skip("Skipping SSH feature compatibility tests in short mode")
}
if runtime.GOOS == "windows" && isCI() {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
}
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
// testCommandWithFlags tests that commands with flags work properly
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Test ls with flags
cmd := exec.Command("ssh",
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// testEnvironmentVariables tests that environment is properly set up
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
cmd := exec.Command("ssh",
"-i", keyFile,
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// testExitCodes tests that exit codes are properly handled
func testExitCodes(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Test successful command (exit code 0)
cmd := exec.Command("ssh",
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
}
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Set up SSH server with specific security settings
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
}
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
t.Logf("Platform command output: %s", outputStr)
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
}
// getTestUsername returns an appropriate username for testing
func getTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Check if this is a system account that can't authenticate
if isSystemAccount(currentUser.Username) {
// In CI environments, create a test user; otherwise try Administrator
if isCI() {
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
// Try Administrator first for local development
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// isCI checks if we're running in a CI environment
func isCI() bool {
// Check standard CI environment variables
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// isSystemAccount checks if the user is a system account that can't authenticate
func isSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
var compatTestCreatedUsers = make(map[string]bool)
var compatTestUsersToCleanup []string
// registerTestUserCleanup registers a test user for cleanup
func registerTestUserCleanup(username string) {
if !compatTestCreatedUsers[username] {
compatTestCreatedUsers[username] = true
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
}
}
// cleanupTestUsers removes all created test users
func cleanupTestUsers() {
for _, username := range compatTestUsersToCleanup {
removeWindowsTestUser(username)
}
compatTestUsersToCleanup = nil
compatTestCreatedUsers = make(map[string]bool)
}
// getOrCreateTestUser creates a test user on Windows if needed
func getOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
// Check if user already exists
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
// Try to create the user using PowerShell
if createWindowsTestUser(t, testUsername) {
// Register cleanup for the test user
registerTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// removeWindowsTestUser removes a local user on Windows using PowerShell
func removeWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
// PowerShell command to remove a local user
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// createWindowsTestUser creates a local user on Windows using PowerShell
func createWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
// PowerShell command to create a local user
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}

View File

@@ -0,0 +1,610 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
cryptossh "golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
)
func TestJWTEnforcement(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT enforcement tests in short mode")
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
if err != nil {
t.Logf("Detection failed: %v", err)
}
t.Logf("Detected server type: %s", serverType)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
})
t.Run("allows_when_disabled", func(t *testing.T) {
serverConfigNoJWT := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
serverNoJWT := New(serverConfigNoJWT)
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
portNoJWT, err := strconv.Atoi(portStrNoJWT)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
assert.False(t, serverType.RequiresJWT())
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
require.NoError(t, err)
defer client.Close()
})
}
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64RawURLEncode(n),
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func base64RawURLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
// generateValidJWT creates a valid JWT token for testing
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
t.Helper()
addr := net.JoinHostPort(host, strconv.Itoa(port))
ctx := context.Background()
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
InsecureSkipVerify: true,
})
}
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
func TestJWTDetection(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT detection test in short mode")
}
jwksServer, _, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
assert.True(t, serverType.RequiresJWT())
}
func TestJWTFailClose(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT fail-close tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
tokenClaims jwt.MapClaims
}{
{
name: "blocks_token_missing_iat",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
},
},
{
name: "blocks_token_missing_sub",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_iss",
tokenClaims: jwt.MapClaims{
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_aud",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_issuer",
tokenClaims: jwt.MapClaims{
"iss": "wrong-issuer",
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_audience",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": "wrong-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_expired_token",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(tokenString),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
assert.Error(t, err, "Authentication should fail (fail-close)")
})
}
}
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
func TestJWTAuthentication(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT authentication tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
token string
wantAuthOK bool
setupServer func(*Server)
testOperation func(*testing.T, *cryptossh.Client, string) error
wantOpSuccess bool
}{
{
name: "allows_shell_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
return session.Shell()
},
wantOpSuccess: true,
},
{
name: "rejects_invalid_token",
token: "invalid",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_shell_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_command_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("ls")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "allows_sftp_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
return session.RequestSubsystem("sftp")
},
wantOpSuccess: true,
},
{
name: "blocks_sftp_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
err = session.RequestSubsystem("sftp")
if err == nil {
err = session.Wait()
}
return err
},
wantOpSuccess: false,
},
{
name: "allows_port_forward_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowRemotePortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: true,
},
{
name: "blocks_port_forward_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowLocalPortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
if tc.setupServer != nil {
tc.setupServer(server)
}
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
if tc.token == "valid" {
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
} else if tc.token == "invalid" {
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),
}
}
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: authMethods,
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed")
} else if err != nil {
t.Logf("Connection failed as expected: %v", err)
return
}
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
err = tc.testOperation(t, conn, serverAddr)
if tc.wantOpSuccess {
require.NoError(t, err, "Operation should succeed")
} else {
assert.Error(t, err, "Operation should fail")
}
})
}
}

View File

@@ -2,18 +2,27 @@ package server
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/version"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
@@ -27,6 +36,9 @@ const (
errExitSession = "exit session error: %v"
msgPrivilegedUserDisabled = "privileged user login is disabled"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
var (
@@ -69,7 +81,6 @@ func (e *UserNotFoundError) Unwrap() error {
}
// safeLogCommand returns a safe representation of the command for logging
// Only logs the first argument to avoid leaking sensitive information
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
return "<empty>"
@@ -80,17 +91,14 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
// sshConnectionState tracks the state of an SSH connection
type sshConnectionState struct {
hasActivePortForward bool
username string
remoteAddr string
}
// Server is the SSH server implementation
type Server struct {
sshServer *ssh.Server
authorizedKeys map[string]ssh.PublicKey
mu sync.RWMutex
hostKeyPEM []byte
sessions map[SessionKey]ssh.Session
@@ -100,30 +108,53 @@ type Server struct {
allowRemotePortForwarding bool
allowRootLogin bool
allowSFTP bool
jwtEnabled bool
netstackNet *netstack.Net
wgAddress wgaddr.Address
ifIdx int
remoteForwardListeners map[ForwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
}
// New creates an SSH server instance with the provided host key
func New(hostKeyPEM []byte) *Server {
return &Server{
type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
}
// Config contains all SSH server configuration options
type Config struct {
// JWT authentication configuration. If nil, JWT authentication is disabled
JWT *JWTConfig
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: hostKeyPEM,
authorizedKeys: make(map[string]ssh.PublicKey),
hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session),
remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
}
return s
}
// Start runs the SSH server, automatically detecting netstack vs standard networking
// Does all setup synchronously, then starts serving in a goroutine and returns immediately
// Start runs the SSH server
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -139,7 +170,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
sshServer, err := s.createSSHServer(ln.Addr())
if err != nil {
s.cleanupOnError(ln)
s.closeListener(ln)
return fmt.Errorf("create SSH server: %w", err)
}
@@ -154,7 +185,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return nil
}
// createListener creates a network listener based on netstack vs standard networking
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
@@ -173,22 +203,15 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
return ln, addr.String(), nil
}
// closeListener safely closes a listener
func (s *Server) closeListener(ln net.Listener) {
if ln == nil {
return
}
if err := ln.Close(); err != nil {
log.Debugf("listener close error: %v", err)
}
}
// cleanupOnError cleans up resources when SSH server creation fails
func (s *Server) cleanupOnError(ln net.Listener) {
if s.ifIdx == 0 || ln == nil {
return
}
s.closeListener(ln)
}
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
@@ -207,28 +230,6 @@ func (s *Server) Stop() error {
return nil
}
// RemoveAuthorizedKey removes the SSH key for a peer
func (s *Server) RemoveAuthorizedKey(peer string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.authorizedKeys, peer)
}
// AddAuthorizedKey adds an SSH key for a peer
func (s *Server) AddAuthorizedKey(peer, newKey string) error {
s.mu.Lock()
defer s.mu.Unlock()
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
if err != nil {
return fmt.Errorf("parse key: %w", err)
}
s.authorizedKeys[peer] = parsedKey
return nil
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
@@ -243,34 +244,195 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
// SetSocketFilter configures eBPF socket filtering for the SSH server
func (s *Server) SetSocketFilter(ifIdx int) {
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
s.mu.RUnlock()
return nil
}
config := s.jwtConfig
s.mu.RUnlock()
if config == nil {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.KeysLocation,
true,
)
extractor := jwt.NewClaimsExtractor(
jwt.WithAudience(config.Audience),
)
s.mu.Lock()
defer s.mu.Unlock()
s.ifIdx = ifIdx
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
s.jwtValidator = validator
s.jwtExtractor = extractor
log.Infof("JWT validator initialized successfully")
return nil
}
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
s.mu.RLock()
defer s.mu.RUnlock()
jwtValidator := s.jwtValidator
jwtConfig := s.jwtConfig
s.mu.RUnlock()
for _, allowed := range s.authorizedKeys {
if ssh.KeysEqual(allowed, key) {
if ctx != nil {
log.Debugf("SSH key authentication successful for user %s from %s", ctx.User(), ctx.RemoteAddr())
if jwtValidator == nil {
return nil, fmt.Errorf("JWT validator not initialized")
}
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
}
return true
}
return nil, fmt.Errorf("validate token: %w", err)
}
if ctx != nil {
log.Warnf("SSH key authentication failed for user %s from %s: key not authorized (type: %s, fingerprint: %s)",
ctx.User(), ctx.RemoteAddr(), key.Type(), cryptossh.FingerprintSHA256(key))
if err := s.checkTokenAge(token, jwtConfig); err != nil {
return nil, err
}
return false
return token, nil
}
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
return nil
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
s.mu.RLock()
jwtExtractor := s.jwtExtractor
s.mu.RUnlock()
if jwtExtractor == nil {
userID := extractUserID(token)
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
}
userAuth, err := jwtExtractor.ToUserAuth(token)
if err != nil {
userID := extractUserID(token)
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
}
if !s.hasSSHAccess(&userAuth) {
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
}
return &userAuth, nil
}
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
return userAuth.UserId != ""
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid token format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("decode payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("parse claims: %w", err)
}
return claims, nil
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
}
// markConnectionActivePortForward marks an SSH connection as having an active port forward
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -286,14 +448,12 @@ func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn,
}
}
// connectionCloseHandler cleans up connection state when SSH connections fail/close
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
// We can't extract the SSH connection from net.Conn directly
// Connection cleanup will happen during session cleanup or via timeout
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
}
// findSessionKeyByContext finds the session key by matching SSH connection context
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
if ctx == nil {
return "unknown"
@@ -319,14 +479,13 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
// Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("using temporary session key for port forward tracking: %s", tempKey)
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey
}
return "unknown"
}
// connectionValidator validates incoming connections based on source IP
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()
netbirdNetwork := s.wgAddress.Network
@@ -340,8 +499,8 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
remoteAddr := conn.RemoteAddr()
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok {
log.Debugf("SSH connection from non-TCP address %s allowed", remoteAddr)
return conn
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
return nil
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
@@ -357,15 +516,14 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
}
if !netbirdNetwork.Contains(remoteIP) {
log.Warnf("SSH connection rejected from non-NetBird IP %s (allowed range: %s)", remoteIP, netbirdNetwork)
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil
}
log.Debugf("SSH connection from %s allowed", remoteIP)
log.Infof("SSH connection from NetBird peer %s allowed", remoteIP)
return conn
}
// isShutdownError checks if the error is expected during normal shutdown
func isShutdownError(err error) bool {
if errors.Is(err, net.ErrClosed) {
return true
@@ -379,12 +537,16 @@ func isShutdownError(err error) bool {
return false
}
// createSSHServer creates and configures the SSH server
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err)
}
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
if s.jwtEnabled {
serverVersion += " " + detection.JWTRequiredMarker
}
server := &ssh.Server{
Addr: addr.String(),
Handler: s.sessionHandler,
@@ -402,6 +564,11 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
},
ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler,
Version: serverVersion,
}
if s.jwtEnabled {
server.PasswordHandler = s.passwordHandler
}
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
@@ -413,14 +580,12 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil
}
// storeRemoteForwardListener stores a remote forward listener for cleanup
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
// removeRemoteForwardListener removes and closes a remote forward listener
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
@@ -438,7 +603,6 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
return true
}
// directTCPIPHandler handles direct-tcpip channel requests for local port forwarding with privilege validation
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
var payload struct {
Host string

View File

@@ -22,12 +22,6 @@ func TestServer_RootLoginRestriction(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
tests := []struct {
name string
allowRoot bool
@@ -117,10 +111,12 @@ func TestServer_RootLoginRestriction(t *testing.T) {
defer cleanup()
// Create server with specific configuration
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(tt.allowRoot)
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Test the userNameLookup method directly
user, err := server.userNameLookup(tt.username)
@@ -196,7 +192,11 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create server with specific configuration
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
@@ -234,17 +234,13 @@ func TestServer_PortConflictHandling(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -263,7 +259,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1()
client1, err := sshclient.DialInsecure(ctx1, serverAddr, currentUser.Username)
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client1.Close()
@@ -274,7 +272,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
client2, err := sshclient.DialInsecure(ctx2, serverAddr, currentUser.Username)
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client2.Close()

View File

@@ -7,11 +7,9 @@ import (
"net/netip"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
@@ -19,82 +17,15 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
)
func TestServer_AddAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
keys := map[string][]byte{}
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys[peer] = remotePubKey
}
for peer, remotePubKey := range keys {
k, ok := server.authorizedKeys[peer]
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
}
}
func TestServer_RemoveAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
require.NoError(t, err)
server.RemoveAuthorizedKey("remotePeer")
_, ok := server.authorizedKeys["remotePeer"]
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
}
func TestServer_PubKeyHandler(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
var keys []ssh.PublicKey
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys = append(keys, remoteParsedPubKey)
}
for _, key := range keys {
accepted := server.publicKeyHandler(nil, key)
assert.True(t, accepted, "SSH key should be accepted")
}
}
func TestServer_StartStop(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
serverConfig := &Config{
HostKeyPEM: key,
JWT: nil,
}
server := New(serverConfig)
err = server.Stop()
assert.NoError(t, err)
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with random port
server := New(hostKey)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server in background
serverAddr := "127.0.0.1:0"
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server
server := New(hostKey)
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate authorized key
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey)
require.NoError(t, err)
// Generate unauthorized key (different from authorized)
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server with only one authorized key
server := New(hostKey)
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
require.NoError(t, err)
}()
// Parse unauthorized private key
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey)
// Generate a client private key for SSH protocol (server doesn't check it)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
// Try to connect with unauthorized key
// Try to connect with client key
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(unauthorizedSigner),
cryptossh.PublicKeys(clientSigner),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// This should succeed in no-auth mode
// This should succeed in no-auth mode (server doesn't verify keys)
conn, err := cryptossh.Dial("tcp", serverAddr, config)
assert.NoError(t, err, "Connection should succeed in no-auth mode")
if conn != nil {
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
serverAddr := "127.0.0.1:0"
// Test multiple start/stop cycles
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server1 := New(hostKey)
server2 := New(hostKey)
serverConfig1 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server1 := New(serverConfig1)
serverConfig2 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server2 := New(serverConfig2)
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")

View File

@@ -35,17 +35,15 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP enabled
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(true)
server.SetAllowRootLogin(true) // Allow root login for testing
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
server.SetAllowRootLogin(true)
// Start server
serverAddr := "127.0.0.1:0"
@@ -144,17 +142,15 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP disabled
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(false)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)

View File

@@ -14,7 +14,6 @@ func StartTestServer(t *testing.T, server *Server) string {
errChan := make(chan error, 1)
go func() {
// Get a free port
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
errChan <- err
@@ -26,9 +25,12 @@ func StartTestServer(t *testing.T, server *Server) string {
return
}
started <- actualAddr
addrPort := netip.MustParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {