mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -9,7 +9,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -21,15 +20,24 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
cleanupTestUsers()
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate OpenSSH-compatible keys for client
|
||||
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t)
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get appropriate user for SSH connection (handle system accounts)
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("basic command execution", func(t *testing.T) {
|
||||
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
||||
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
|
||||
// testSSHInteractiveCommand tests interactive shell session.
|
||||
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// testSSHPortForwarding tests port forwarding compatibility.
|
||||
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && isCI() {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
// testCommandWithFlags tests that commands with flags work properly
|
||||
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test ls with flags
|
||||
cmd := exec.Command("ssh",
|
||||
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// testEnvironmentVariables tests that environment is properly set up
|
||||
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// testExitCodes tests that exit codes are properly handled
|
||||
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test successful command (exit code 0)
|
||||
cmd := exec.Command("ssh",
|
||||
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server with specific security settings
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
t.Logf("Platform command output: %s", outputStr)
|
||||
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
||||
}
|
||||
|
||||
// getTestUsername returns an appropriate username for testing
|
||||
func getTestUsername(t *testing.T) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Check if this is a system account that can't authenticate
|
||||
if isSystemAccount(currentUser.Username) {
|
||||
// In CI environments, create a test user; otherwise try Administrator
|
||||
if isCI() {
|
||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
} else {
|
||||
// Try Administrator first for local development
|
||||
if _, err := user.Lookup("Administrator"); err == nil {
|
||||
return "Administrator"
|
||||
}
|
||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
// isCI checks if we're running in a CI environment
|
||||
func isCI() bool {
|
||||
// Check standard CI environment variables
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isSystemAccount checks if the user is a system account that can't authenticate
|
||||
func isSystemAccount(username string) bool {
|
||||
systemAccounts := []string{
|
||||
"system",
|
||||
"NT AUTHORITY\\SYSTEM",
|
||||
"NT AUTHORITY\\LOCAL SERVICE",
|
||||
"NT AUTHORITY\\NETWORK SERVICE",
|
||||
}
|
||||
|
||||
for _, sysAccount := range systemAccounts {
|
||||
if strings.EqualFold(username, sysAccount) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var compatTestCreatedUsers = make(map[string]bool)
|
||||
var compatTestUsersToCleanup []string
|
||||
|
||||
// registerTestUserCleanup registers a test user for cleanup
|
||||
func registerTestUserCleanup(username string) {
|
||||
if !compatTestCreatedUsers[username] {
|
||||
compatTestCreatedUsers[username] = true
|
||||
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupTestUsers removes all created test users
|
||||
func cleanupTestUsers() {
|
||||
for _, username := range compatTestUsersToCleanup {
|
||||
removeWindowsTestUser(username)
|
||||
}
|
||||
compatTestUsersToCleanup = nil
|
||||
compatTestCreatedUsers = make(map[string]bool)
|
||||
}
|
||||
|
||||
// getOrCreateTestUser creates a test user on Windows if needed
|
||||
func getOrCreateTestUser(t *testing.T) string {
|
||||
testUsername := "netbird-test-user"
|
||||
|
||||
// Check if user already exists
|
||||
if _, err := user.Lookup(testUsername); err == nil {
|
||||
return testUsername
|
||||
}
|
||||
|
||||
// Try to create the user using PowerShell
|
||||
if createWindowsTestUser(t, testUsername) {
|
||||
// Register cleanup for the test user
|
||||
registerTestUserCleanup(testUsername)
|
||||
return testUsername
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// removeWindowsTestUser removes a local user on Windows using PowerShell
|
||||
func removeWindowsTestUser(username string) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
// PowerShell command to remove a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
||||
Write-Output "User removed successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*cannot be found*") {
|
||||
Write-Output "User not found (already removed)"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
}
|
||||
}
|
||||
`, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
||||
} else {
|
||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// createWindowsTestUser creates a local user on Windows using PowerShell
|
||||
func createWindowsTestUser(t *testing.T, username string) bool {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false
|
||||
}
|
||||
|
||||
// PowerShell command to create a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
||||
Write-Output "User created successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*already exists*") {
|
||||
Write-Output "User already exists"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
`, username, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
|
||||
return false
|
||||
}
|
||||
|
||||
t.Logf("Test user creation result: %s", string(output))
|
||||
return true
|
||||
}
|
||||
|
||||
610
client/ssh/server/jwt_test.go
Normal file
610
client/ssh/server/jwt_test.go
Normal file
@@ -0,0 +1,610 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT enforcement tests in short mode")
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
if err != nil {
|
||||
t.Logf("Detection failed: %v", err)
|
||||
}
|
||||
t.Logf("Detected server type: %s", serverType)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
|
||||
})
|
||||
|
||||
t.Run("allows_when_disabled", func(t *testing.T) {
|
||||
serverConfigNoJWT := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
serverNoJWT := New(serverConfigNoJWT)
|
||||
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||
assert.False(t, serverType.RequiresJWT())
|
||||
|
||||
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64RawURLEncode(n),
|
||||
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func base64RawURLEncode(data []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// generateValidJWT creates a valid JWT token for testing
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
|
||||
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
|
||||
t.Helper()
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
ctx := context.Background()
|
||||
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
|
||||
func TestJWTDetection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT detection test in short mode")
|
||||
}
|
||||
|
||||
jwksServer, _, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||
assert.True(t, serverType.RequiresJWT())
|
||||
}
|
||||
|
||||
func TestJWTFailClose(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT fail-close tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenClaims jwt.MapClaims
|
||||
}{
|
||||
{
|
||||
name: "blocks_token_missing_iat",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_sub",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_iss",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_aud",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_issuer",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": "wrong-issuer",
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_audience",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": "wrong-audience",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_expired_token",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
MaxTokenAge: 3600,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(tokenString),
|
||||
},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
assert.Error(t, err, "Authentication should fail (fail-close)")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
|
||||
func TestJWTAuthentication(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT authentication tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantAuthOK bool
|
||||
setupServer func(*Server)
|
||||
testOperation func(*testing.T, *cryptossh.Client, string) error
|
||||
wantOpSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "allows_shell_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
return session.Shell()
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "rejects_invalid_token",
|
||||
token: "invalid",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_shell_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_command_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("ls")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_sftp_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
return session.RequestSubsystem("sftp")
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_sftp_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
err = session.RequestSubsystem("sftp")
|
||||
if err == nil {
|
||||
err = session.Wait()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_port_forward_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowRemotePortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_port_forward_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowLocalPortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
if tc.setupServer != nil {
|
||||
tc.setupServer(server)
|
||||
}
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
var authMethods []cryptossh.AuthMethod
|
||||
if tc.token == "valid" {
|
||||
token := generateValidJWT(t, privateKey, issuer, audience)
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
}
|
||||
} else if tc.token == "invalid" {
|
||||
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(invalidToken),
|
||||
}
|
||||
}
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if tc.wantAuthOK {
|
||||
require.NoError(t, err, "JWT authentication should succeed")
|
||||
} else if err != nil {
|
||||
t.Logf("Connection failed as expected: %v", err)
|
||||
return
|
||||
}
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err = tc.testOperation(t, conn, serverAddr)
|
||||
if tc.wantOpSuccess {
|
||||
require.NoError(t, err, "Operation should succeed")
|
||||
} else {
|
||||
assert.Error(t, err, "Operation should fail")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,18 +2,27 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
@@ -27,6 +36,9 @@ const (
|
||||
errExitSession = "exit session error: %v"
|
||||
|
||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||
|
||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||
DefaultJWTMaxTokenAge = 5 * 60
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,7 +81,6 @@ func (e *UserNotFoundError) Unwrap() error {
|
||||
}
|
||||
|
||||
// safeLogCommand returns a safe representation of the command for logging
|
||||
// Only logs the first argument to avoid leaking sensitive information
|
||||
func safeLogCommand(cmd []string) string {
|
||||
if len(cmd) == 0 {
|
||||
return "<empty>"
|
||||
@@ -80,17 +91,14 @@ func safeLogCommand(cmd []string) string {
|
||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||
}
|
||||
|
||||
// sshConnectionState tracks the state of an SSH connection
|
||||
type sshConnectionState struct {
|
||||
hasActivePortForward bool
|
||||
username string
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
// Server is the SSH server implementation
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
authorizedKeys map[string]ssh.PublicKey
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
sessions map[SessionKey]ssh.Session
|
||||
@@ -100,30 +108,53 @@ type Server struct {
|
||||
allowRemotePortForwarding bool
|
||||
allowRootLogin bool
|
||||
allowSFTP bool
|
||||
jwtEnabled bool
|
||||
|
||||
netstackNet *netstack.Net
|
||||
|
||||
wgAddress wgaddr.Address
|
||||
ifIdx int
|
||||
|
||||
remoteForwardListeners map[ForwardKey]net.Listener
|
||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
||||
|
||||
jwtValidator *jwt.Validator
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key
|
||||
func New(hostKeyPEM []byte) *Server {
|
||||
return &Server{
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
}
|
||||
|
||||
// Config contains all SSH server configuration options
|
||||
type Config struct {
|
||||
// JWT authentication configuration. If nil, JWT authentication is disabled
|
||||
JWT *JWTConfig
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: hostKeyPEM,
|
||||
authorizedKeys: make(map[string]ssh.PublicKey),
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
sessions: make(map[SessionKey]ssh.Session),
|
||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start runs the SSH server, automatically detecting netstack vs standard networking
|
||||
// Does all setup synchronously, then starts serving in a goroutine and returns immediately
|
||||
// Start runs the SSH server
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -139,7 +170,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
|
||||
sshServer, err := s.createSSHServer(ln.Addr())
|
||||
if err != nil {
|
||||
s.cleanupOnError(ln)
|
||||
s.closeListener(ln)
|
||||
return fmt.Errorf("create SSH server: %w", err)
|
||||
}
|
||||
|
||||
@@ -154,7 +185,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createListener creates a network listener based on netstack vs standard networking
|
||||
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
@@ -173,22 +203,15 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
|
||||
return ln, addr.String(), nil
|
||||
}
|
||||
|
||||
// closeListener safely closes a listener
|
||||
func (s *Server) closeListener(ln net.Listener) {
|
||||
if ln == nil {
|
||||
return
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("listener close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupOnError cleans up resources when SSH server creation fails
|
||||
func (s *Server) cleanupOnError(ln net.Listener) {
|
||||
if s.ifIdx == 0 || ln == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.closeListener(ln)
|
||||
}
|
||||
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
@@ -207,28 +230,6 @@ func (s *Server) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAuthorizedKey removes the SSH key for a peer
|
||||
func (s *Server) RemoveAuthorizedKey(peer string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.authorizedKeys, peer)
|
||||
}
|
||||
|
||||
// AddAuthorizedKey adds an SSH key for a peer
|
||||
func (s *Server) AddAuthorizedKey(peer, newKey string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
|
||||
s.authorizedKeys[peer] = parsedKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
@@ -243,34 +244,195 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// SetSocketFilter configures eBPF socket filtering for the SSH server
|
||||
func (s *Server) SetSocketFilter(ifIdx int) {
|
||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
s.mu.RLock()
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
s.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
config := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if config == nil {
|
||||
return fmt.Errorf("JWT config not set")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||
|
||||
validator := jwt.NewValidator(
|
||||
config.Issuer,
|
||||
[]string{config.Audience},
|
||||
config.KeysLocation,
|
||||
true,
|
||||
)
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(
|
||||
jwt.WithAudience(config.Audience),
|
||||
)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ifIdx = ifIdx
|
||||
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.jwtValidator = validator
|
||||
s.jwtExtractor = extractor
|
||||
|
||||
log.Infof("JWT validator initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
jwtValidator := s.jwtValidator
|
||||
jwtConfig := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
for _, allowed := range s.authorizedKeys {
|
||||
if ssh.KeysEqual(allowed, key) {
|
||||
if ctx != nil {
|
||||
log.Debugf("SSH key authentication successful for user %s from %s", ctx.User(), ctx.RemoteAddr())
|
||||
if jwtValidator == nil {
|
||||
return nil, fmt.Errorf("JWT validator not initialized")
|
||||
}
|
||||
|
||||
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
|
||||
if err != nil {
|
||||
if jwtConfig != nil {
|
||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
log.Warnf("SSH key authentication failed for user %s from %s: key not authorized (type: %s, fingerprint: %s)",
|
||||
ctx.User(), ctx.RemoteAddr(), key.Type(), cryptossh.FingerprintSHA256(key))
|
||||
if err := s.checkTokenAge(token, jwtConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return false
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
|
||||
s.mu.RLock()
|
||||
jwtExtractor := s.jwtExtractor
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtExtractor == nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
|
||||
}
|
||||
|
||||
userAuth, err := jwtExtractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
|
||||
}
|
||||
|
||||
if !s.hasSSHAccess(&userAuth) {
|
||||
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
|
||||
}
|
||||
|
||||
return &userAuth, nil
|
||||
}
|
||||
|
||||
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
|
||||
return userAuth.UserId != ""
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("parse claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||
return true
|
||||
}
|
||||
|
||||
// markConnectionActivePortForward marks an SSH connection as having an active port forward
|
||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -286,14 +448,12 @@ func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn,
|
||||
}
|
||||
}
|
||||
|
||||
// connectionCloseHandler cleans up connection state when SSH connections fail/close
|
||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
||||
// We can't extract the SSH connection from net.Conn directly
|
||||
// Connection cleanup will happen during session cleanup or via timeout
|
||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
// findSessionKeyByContext finds the session key by matching SSH connection context
|
||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
if ctx == nil {
|
||||
return "unknown"
|
||||
@@ -319,14 +479,13 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
// Return a temporary key that we'll fix up later
|
||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||
log.Debugf("using temporary session key for port forward tracking: %s", tempKey)
|
||||
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
||||
return tempKey
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// connectionValidator validates incoming connections based on source IP
|
||||
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
s.mu.RLock()
|
||||
netbirdNetwork := s.wgAddress.Network
|
||||
@@ -340,8 +499,8 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
remoteAddr := conn.RemoteAddr()
|
||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
log.Debugf("SSH connection from non-TCP address %s allowed", remoteAddr)
|
||||
return conn
|
||||
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
@@ -357,15 +516,14 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
}
|
||||
|
||||
if !netbirdNetwork.Contains(remoteIP) {
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s (allowed range: %s)", remoteIP, netbirdNetwork)
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("SSH connection from %s allowed", remoteIP)
|
||||
log.Infof("SSH connection from NetBird peer %s allowed", remoteIP)
|
||||
return conn
|
||||
}
|
||||
|
||||
// isShutdownError checks if the error is expected during normal shutdown
|
||||
func isShutdownError(err error) bool {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
@@ -379,12 +537,16 @@ func isShutdownError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// createSSHServer creates and configures the SSH server
|
||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
if err := enableUserSwitching(); err != nil {
|
||||
log.Warnf("failed to enable user switching: %v", err)
|
||||
}
|
||||
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
|
||||
if s.jwtEnabled {
|
||||
serverVersion += " " + detection.JWTRequiredMarker
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
Addr: addr.String(),
|
||||
Handler: s.sessionHandler,
|
||||
@@ -402,6 +564,11 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
},
|
||||
ConnCallback: s.connectionValidator,
|
||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
if s.jwtEnabled {
|
||||
server.PasswordHandler = s.passwordHandler
|
||||
}
|
||||
|
||||
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
|
||||
@@ -413,14 +580,12 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// storeRemoteForwardListener stores a remote forward listener for cleanup
|
||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.remoteForwardListeners[key] = ln
|
||||
}
|
||||
|
||||
// removeRemoteForwardListener removes and closes a remote forward listener
|
||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -438,7 +603,6 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// directTCPIPHandler handles direct-tcpip channel requests for local port forwarding with privilege validation
|
||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||
var payload struct {
|
||||
Host string
|
||||
|
||||
@@ -22,12 +22,6 @@ func TestServer_RootLoginRestriction(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
@@ -117,10 +111,12 @@ func TestServer_RootLoginRestriction(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
// Create server with specific configuration
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(tt.allowRoot)
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test the userNameLookup method directly
|
||||
user, err := server.userNameLookup(tt.username)
|
||||
@@ -196,7 +192,11 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create server with specific configuration
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
||||
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
||||
|
||||
@@ -234,17 +234,13 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -263,7 +259,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
client1, err := sshclient.DialInsecure(ctx1, serverAddr, currentUser.Username)
|
||||
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client1.Close()
|
||||
@@ -274,7 +272,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
client2, err := sshclient.DialInsecure(ctx2, serverAddr, currentUser.Username)
|
||||
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client2.Close()
|
||||
|
||||
@@ -7,11 +7,9 @@ import (
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
@@ -19,82 +17,15 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestServer_AddAuthorizedKey(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
keys := map[string][]byte{}
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
keys[peer] = remotePubKey
|
||||
}
|
||||
|
||||
for peer, remotePubKey := range keys {
|
||||
k, ok := server.authorizedKeys[peer]
|
||||
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
|
||||
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RemoveAuthorizedKey(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
server.RemoveAuthorizedKey("remotePeer")
|
||||
|
||||
_, ok := server.authorizedKeys["remotePeer"]
|
||||
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
|
||||
}
|
||||
|
||||
func TestServer_PubKeyHandler(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
var keys []ssh.PublicKey
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
keys = append(keys, remoteParsedPubKey)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
accepted := server.publicKeyHandler(nil, key)
|
||||
assert.True(t, accepted, "SSH key should be accepted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(key)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: key,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with random port
|
||||
server := New(hostKey)
|
||||
|
||||
// Add client's public key as authorized
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server in background
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
server := New(hostKey)
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate authorized key
|
||||
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate unauthorized key (different from authorized)
|
||||
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with only one authorized key
|
||||
server := New(hostKey)
|
||||
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
|
||||
require.NoError(t, err)
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse unauthorized private key
|
||||
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey)
|
||||
// Generate a client private key for SSH protocol (server doesn't check it)
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Try to connect with unauthorized key
|
||||
// Try to connect with client key
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(unauthorizedSigner),
|
||||
cryptossh.PublicKeys(clientSigner),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// This should succeed in no-auth mode
|
||||
// This should succeed in no-auth mode (server doesn't verify keys)
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
||||
if conn != nil {
|
||||
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
serverAddr := "127.0.0.1:0"
|
||||
|
||||
// Test multiple start/stop cycles
|
||||
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server1 := New(hostKey)
|
||||
server2 := New(hostKey)
|
||||
serverConfig1 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server1 := New(serverConfig1)
|
||||
|
||||
serverConfig2 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server2 := New(serverConfig2)
|
||||
|
||||
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
||||
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
||||
|
||||
@@ -35,17 +35,15 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP enabled
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(true)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
|
||||
// Add client's public key as authorized
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -144,17 +142,15 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP disabled
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(false)
|
||||
|
||||
// Add client's public key as authorized
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
|
||||
@@ -14,7 +14,6 @@ func StartTestServer(t *testing.T, server *Server) string {
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Get a free port
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
@@ -26,9 +25,12 @@ func StartTestServer(t *testing.T, server *Server) string {
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort := netip.MustParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
Reference in New Issue
Block a user