Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -21,6 +21,8 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
)
const (
@@ -40,12 +42,10 @@ type Client struct {
windowsStdinMode uint32 // nolint:unused
}
// Close terminates the SSH connection
func (c *Client) Close() error {
return c.client.Close()
}
// OpenTerminal opens an interactive terminal session
func (c *Client) OpenTerminal(ctx context.Context) error {
session, err := c.client.NewSession()
if err != nil {
@@ -259,43 +259,29 @@ func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error
return session, cleanup, nil
}
// Dial connects to the given ssh server with proper host key verification
func Dial(ctx context.Context, addr, user string) (*Client, error) {
hostKeyCallback, err := createHostKeyCallback(addr)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
func getDefaultDaemonAddr() string {
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
return addr
}
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: hostKeyCallback,
if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows
}
return dial(ctx, "tcp", addr, config)
}
// DialInsecure connects to the given ssh server without host key verification (for testing only)
func DialInsecure(ctx context.Context, addr, user string) (*Client, error) {
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 - Only used for tests
}
return dial(ctx, "tcp", addr, config)
return DefaultDaemonAddr
}
// DialOptions contains options for SSH connections
type DialOptions struct {
KnownHostsFile string
IdentityFile string
DaemonAddr string
KnownHostsFile string
IdentityFile string
DaemonAddr string
SkipCachedToken bool
InsecureSkipVerify bool
}
// DialWithOptions connects to the given ssh server with specified options
func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
hostKeyCallback, err := createHostKeyCallbackWithOptions(addr, opts)
// Dial connects to the given ssh server with specified options
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
hostKeyCallback, err := createHostKeyCallback(opts)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
}
@@ -306,7 +292,6 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
HostKeyCallback: hostKeyCallback,
}
// Add SSH key authentication if identity file is specified
if opts.IdentityFile != "" {
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
if err != nil {
@@ -315,11 +300,16 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
config.Auth = append(config.Auth, authMethod)
}
return dial(ctx, "tcp", addr, config)
daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
}
// dial establishes an SSH connection
func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
// dialSSH establishes an SSH connection without JWT authentication
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
@@ -340,143 +330,84 @@ func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (
}, nil
}
// createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files
func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) {
daemonAddr := os.Getenv("NB_DAEMON_ADDR")
if daemonAddr == "" {
if runtime.GOOS == "windows" {
daemonAddr = DefaultDaemonAddrWindows
} else {
daemonAddr = DefaultDaemonAddr
}
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("parse address %s: %w", addr, err)
}
return createHostKeyCallbackWithDaemonAddr(addr, daemonAddr)
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
return nil, fmt.Errorf("SSH server detection failed: %w", err)
}
if !serverType.RequiresJWT() {
return dialSSH(ctx, network, addr, config)
}
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
if err != nil {
return nil, fmt.Errorf("request JWT token: %w", err)
}
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
return dialSSH(ctx, network, addr, configWithJWT)
}
// createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address
func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCallback, error) {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// First try to get host key from NetBird daemon
if err := verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr); err == nil {
return nil
}
// requestJWTToken requests a JWT token from the NetBird daemon
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
conn, err := connectToDaemon(daemonAddr)
if err != nil {
return "", fmt.Errorf("connect to daemon: %w", err)
}
defer conn.Close()
// Fallback to known_hosts files
knownHostsFiles := getKnownHostsFiles()
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
// Try each known_hosts callback
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
}, nil
client := proto.NewDaemonServiceClient(conn)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
}
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
client, err := connectToDaemon(daemonAddr)
conn, err := connectToDaemon(daemonAddr)
if err != nil {
return err
}
defer func() {
if err := client.Close(); err != nil {
if err := conn.Close(); err != nil {
log.Debugf("daemon connection close error: %v", err)
}
}()
addresses := buildAddressList(hostname, remote)
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
return verifyKeyWithDaemon(client, addresses, key)
client := proto.NewDaemonServiceClient(conn)
verifier := nbssh.NewDaemonHostKeyVerifier(client)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
addr := strings.TrimPrefix(daemonAddr, "tcp://")
conn, err := grpc.DialContext(
ctx,
strings.TrimPrefix(daemonAddr, "tcp://"),
conn, err := grpc.NewClient(
addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err)
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
}
return conn, nil
}
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}
func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error {
client := proto.NewDaemonServiceClient(conn)
for _, addr := range addresses {
if err := checkAddressKey(client, addr, key); err == nil {
return nil
}
}
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
}
func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: addr,
})
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
if err != nil {
log.Debugf("daemon query error for %s: %v", addr, err)
return err
}
if !response.GetFound() {
log.Debugf("SSH host key not found in daemon for address: %s", addr)
return fmt.Errorf("key not found")
}
return compareKeys(response.GetSshHostKey(), key, addr)
}
func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error {
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
return err
}
if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) {
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
return nil
}
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type())
return fmt.Errorf("key mismatch")
}
// getKnownHostsFiles returns paths to known_hosts files in order of preference
func getKnownHostsFiles() []string {
var files []string
@@ -503,8 +434,12 @@ func getKnownHostsFiles() []string {
return files
}
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) {
// createHostKeyCallback creates a host key verification callback
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
if opts.InsecureSkipVerify {
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
}
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
return nil

View File

@@ -7,29 +7,36 @@ import (
"io"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
cleanupTestUsers()
testutil.CleanupTestUsers()
os.Exit(code)
}
@@ -39,19 +46,14 @@ func TestSSHClient_DialWithKey(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create and start server
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
@@ -62,8 +64,10 @@ func TestSSHClient_DialWithKey(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client.Close()
@@ -75,7 +79,7 @@ func TestSSHClient_DialWithKey(t *testing.T) {
}
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && isCI() {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
@@ -129,20 +133,16 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
}()
// Generate client key for multiple connections
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("multi-peer", string(clientPubKey))
require.NoError(t, err)
const numClients = 3
clients := make([]*Client, numClients)
for i := 0; i < numClients; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i))
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
InsecureSkipVerify: true,
})
cancel()
require.NoError(t, err, "Client %d should connect successfully", i)
clients[i] = client
@@ -161,19 +161,14 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
require.NoError(t, err)
}()
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("cancel-peer", string(clientPubKey))
require.NoError(t, err)
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := getCurrentUsername()
_, err = DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
@@ -187,8 +182,10 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
@@ -214,7 +211,11 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server)
@@ -226,10 +227,12 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
currentUser := testutil.GetTestUsername(t)
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
client, err := DialInsecure(ctx, serverAddr, currentUser)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
assert.NoError(t, err)
if client != nil {
require.NoError(t, client.Close(), "Client should close without error")
@@ -282,24 +285,22 @@ func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Clie
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
return server, serverAddr, client
@@ -361,18 +362,14 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowLocalPortForwarding(true)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
@@ -387,11 +384,13 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
require.NoError(t, err)
// Skip if running as system account that can't do port forwarding
if isSystemAccount(realUser) {
if testutil.IsSystemAccount(realUser) {
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
}
client, err := DialInsecure(ctx, serverAddr, realUser)
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
InsecureSkipVerify: true, // Skip host key verification for test
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
@@ -478,180 +477,6 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
assert.Equal(t, expectedResponse, string(response))
}
// getCurrentUsername returns the current username for SSH connections
func getCurrentUsername() string {
if runtime.GOOS == "windows" {
if currentUser, err := user.Current(); err == nil {
// Check if this is a system account that can't authenticate
if isSystemAccount(currentUser.Username) {
// In CI environments, create a test user; otherwise try Administrator
if isCI() {
if testUser := getOrCreateTestUser(); testUser != "" {
return testUser
}
} else {
// Try Administrator first for local development
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := getOrCreateTestUser(); testUser != "" {
return testUser
}
}
}
// On Windows, return the full domain\username for proper authentication
return currentUser.Username
}
}
if username := os.Getenv("USER"); username != "" {
return username
}
if currentUser, err := user.Current(); err == nil {
return currentUser.Username
}
return "test-user"
}
// isCI checks if we're running in GitHub Actions CI
func isCI() bool {
// Check standard CI environment variables
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// getOrCreateTestUser creates a test user on Windows if needed
func getOrCreateTestUser() string {
testUsername := "netbird-test-user"
// Check if user already exists
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
// Try to create the user using PowerShell
if createWindowsTestUser(testUsername) {
// Register cleanup for the test user
registerTestUserCleanup(testUsername)
return testUsername
}
return ""
}
var createdTestUsers = make(map[string]bool)
var testUsersToCleanup []string
// registerTestUserCleanup registers a test user for cleanup
func registerTestUserCleanup(username string) {
if !createdTestUsers[username] {
createdTestUsers[username] = true
testUsersToCleanup = append(testUsersToCleanup, username)
}
}
// cleanupTestUsers removes all created test users
func cleanupTestUsers() {
for _, username := range testUsersToCleanup {
removeWindowsTestUser(username)
}
testUsersToCleanup = nil
createdTestUsers = make(map[string]bool)
}
// removeWindowsTestUser removes a local user on Windows using PowerShell
func removeWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
// PowerShell command to remove a local user
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// createWindowsTestUser creates a local user on Windows using PowerShell
func createWindowsTestUser(username string) bool {
if runtime.GOOS != "windows" {
return false
}
// PowerShell command to create a local user
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
log.Printf("Test user creation result: %s", string(output))
return true
}
// isSystemAccount checks if the user is a system account that can't authenticate
func isSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
func getRealCurrentUser() (string, error) {
if runtime.GOOS == "windows" {

167
client/ssh/common.go Normal file
View File

@@ -0,0 +1,167 @@
package ssh
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/proto"
)
const (
NetBirdSSHConfigFile = "99-netbird.conf"
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
WindowsSSHConfigDir = "ssh/ssh_config.d"
)
var (
// ErrPeerNotFound indicates the peer was not found in the network
ErrPeerNotFound = errors.New("peer not found in network")
// ErrNoStoredKey indicates the peer has no stored SSH host key
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
)
// HostKeyVerifier provides SSH host key verification
type HostKeyVerifier interface {
VerifySSHHostKey(peerAddress string, key []byte) error
}
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
type DaemonHostKeyVerifier struct {
client proto.DaemonServiceClient
}
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
return &DaemonHostKeyVerifier{
client: client,
}
}
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: peerAddress,
})
if err != nil {
return err
}
if !response.GetFound() {
return ErrPeerNotFound
}
storedKeyData := response.GetSshHostKey()
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
}
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
if err != nil {
return "", fmt.Errorf("request JWT auth: %w", err)
}
if useCache && authResponse.CachedToken != "" {
log.Debug("Using cached authentication token")
return authResponse.CachedToken, nil
}
if stderr != nil {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
}
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
}
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
DeviceCode: authResponse.DeviceCode,
UserCode: authResponse.UserCode,
})
if err != nil {
return "", fmt.Errorf("wait for JWT token: %w", err)
}
if stdout != nil {
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
}
return tokenResponse.Token, nil
}
// VerifyHostKey verifies an SSH host key against stored peer key data.
// Returns nil only if the presented key matches the stored key.
// Returns ErrNoStoredKey if storedKeyData is empty.
// Returns an error if the keys don't match or if parsing fails.
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
if len(storedKeyData) == 0 {
return ErrNoStoredKey
}
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
}
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
}
return nil
}
// AddJWTAuth prepends JWT password authentication to existing auth methods.
// This ensures JWT auth is tried first while preserving any existing auth methods.
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
configWithJWT := *config
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
return &configWithJWT
}
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
// It tries multiple addresses (hostname, IP) for the peer before failing.
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
addresses := buildAddressList(hostname, remote)
presentedKey := key.Marshal()
for _, addr := range addresses {
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
if errors.Is(err, ErrPeerNotFound) {
// Try other addresses for this peer
continue
}
return err
}
// Verified
return nil
}
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
}
}
// buildAddressList creates a list of addresses to check for host key verification.
// It includes the original hostname and extracts the host part from the remote address if different.
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}

View File

@@ -1,7 +1,6 @@
package config
import (
"bufio"
"context"
"fmt"
"os"
@@ -12,50 +11,41 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
const (
// EnvDisableSSHConfig is the environment variable to disable SSH config management
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
// EnvForceSSHConfig is the environment variable to force SSH config generation even with many peers
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
// MaxPeersForSSHConfig is the default maximum number of peers before SSH config generation is disabled
MaxPeersForSSHConfig = 200
// fileWriteTimeout is the timeout for file write operations
fileWriteTimeout = 2 * time.Second
)
// isSSHConfigDisabled checks if SSH config management is disabled via environment variable
func isSSHConfigDisabled() bool {
value := os.Getenv(EnvDisableSSHConfig)
if value == "" {
return false
}
// Parse as boolean, default to true if non-empty but invalid
disabled, err := strconv.ParseBool(value)
if err != nil {
// If not a valid boolean, treat any non-empty value as true
return true
}
return disabled
}
// isSSHConfigForced checks if SSH config generation is forced via environment variable
func isSSHConfigForced() bool {
value := os.Getenv(EnvForceSSHConfig)
if value == "" {
return false
}
// Parse as boolean, default to true if non-empty but invalid
forced, err := strconv.ParseBool(value)
if err != nil {
// If not a valid boolean, treat any non-empty value as true
return true
}
return forced
@@ -92,85 +82,55 @@ func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error
}
}
// writeFileOperationWithTimeout performs a file operation with timeout
func writeFileOperationWithTimeout(filename string, operation func() error) error {
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
done <- operation()
}()
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
}
}
// Manager handles SSH client configuration for NetBird peers
type Manager struct {
sshConfigDir string
sshConfigFile string
knownHostsDir string
knownHostsFile string
userKnownHosts string
sshConfigDir string
sshConfigFile string
}
// PeerHostKey represents a peer's SSH host key information
type PeerHostKey struct {
// PeerSSHInfo represents a peer's SSH configuration information
type PeerSSHInfo struct {
Hostname string
IP string
FQDN string
HostKey ssh.PublicKey
}
// NewManager creates a new SSH config manager
func NewManager() *Manager {
sshConfigDir, knownHostsDir := getSystemSSHPaths()
// New creates a new SSH config manager
func New() *Manager {
sshConfigDir := getSystemSSHConfigDir()
return &Manager{
sshConfigDir: sshConfigDir,
sshConfigFile: "99-netbird.conf",
knownHostsDir: knownHostsDir,
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: sshConfigDir,
sshConfigFile: nbssh.NetBirdSSHConfigFile,
}
}
// getSystemSSHPaths returns platform-specific SSH configuration paths
func getSystemSSHPaths() (configDir, knownHostsDir string) {
switch runtime.GOOS {
case "windows":
configDir, knownHostsDir = getWindowsSSHPaths()
default:
// Unix-like systems (Linux, macOS, etc.)
configDir = "/etc/ssh/ssh_config.d"
knownHostsDir = "/etc/ssh/ssh_known_hosts.d"
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
func getSystemSSHConfigDir() string {
if runtime.GOOS == "windows" {
return getWindowsSSHConfigDir()
}
return configDir, knownHostsDir
return nbssh.UnixSSHConfigDir
}
func getWindowsSSHPaths() (configDir, knownHostsDir string) {
func getWindowsSSHConfigDir() string {
programData := os.Getenv("PROGRAMDATA")
if programData == "" {
programData = `C:\ProgramData`
}
configDir = filepath.Join(programData, "ssh", "ssh_config.d")
knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d")
return configDir, knownHostsDir
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
}
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
func (m *Manager) SetupSSHClientConfig(peerKeys []PeerHostKey) error {
if !shouldGenerateSSHConfig(len(peerKeys)) {
m.logSkipReason(len(peerKeys))
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
if !shouldGenerateSSHConfig(len(peers)) {
m.logSkipReason(len(peers))
return nil
}
knownHostsPath := m.getKnownHostsPath()
sshConfig := m.buildSSHConfig(peerKeys, knownHostsPath)
sshConfig, err := m.buildSSHConfig(peers)
if err != nil {
return fmt.Errorf("build SSH config: %w", err)
}
return m.writeSSHConfig(sshConfig)
}
@@ -183,21 +143,24 @@ func (m *Manager) logSkipReason(peerCount int) {
}
}
func (m *Manager) getKnownHostsPath() string {
knownHostsPath, err := m.setupKnownHostsFile()
if err != nil {
log.Warnf("Failed to setup known_hosts file: %v", err)
return "/dev/null"
}
return knownHostsPath
}
func (m *Manager) buildSSHConfig(peerKeys []PeerHostKey, knownHostsPath string) string {
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
sshConfig := m.buildConfigHeader()
for _, peer := range peerKeys {
sshConfig += m.buildPeerConfig(peer, knownHostsPath)
var allHostPatterns []string
for _, peer := range peers {
hostPatterns := m.buildHostPatterns(peer)
allHostPatterns = append(allHostPatterns, hostPatterns...)
}
return sshConfig
if len(allHostPatterns) > 0 {
peerConfig, err := m.buildPeerConfig(allHostPatterns)
if err != nil {
return "", err
}
sshConfig += peerConfig
}
return sshConfig, nil
}
func (m *Manager) buildConfigHeader() string {
@@ -209,25 +172,49 @@ func (m *Manager) buildConfigHeader() string {
"#\n\n"
}
func (m *Manager) buildPeerConfig(peer PeerHostKey, knownHostsPath string) string {
hostPatterns := m.buildHostPatterns(peer)
if len(hostPatterns) == 0 {
return ""
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
uniquePatterns := make(map[string]bool)
var deduplicatedPatterns []string
for _, pattern := range allHostPatterns {
if !uniquePatterns[pattern] {
uniquePatterns[pattern] = true
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
}
}
hostLine := strings.Join(hostPatterns, " ")
execPath, err := m.getNetBirdExecutablePath()
if err != nil {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
config += " # NetBird peer-specific configuration\n"
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += m.buildHostKeyConfig(knownHostsPath)
config += " LogLevel ERROR\n\n"
return config
if runtime.GOOS == "windows" {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
} else {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
}
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}
func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
var hostPatterns []string
if peer.IP != "" {
hostPatterns = append(hostPatterns, peer.IP)
@@ -241,280 +228,55 @@ func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
return hostPatterns
}
func (m *Manager) buildHostKeyConfig(knownHostsPath string) string {
if knownHostsPath == "/dev/null" {
return " StrictHostKeyChecking no\n" +
" UserKnownHostsFile /dev/null\n"
}
return " StrictHostKeyChecking yes\n" +
fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath)
}
func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err)
return m.setupUserConfig(sshConfig)
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
}
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
log.Warnf("Failed to write SSH config file %s: %v", sshConfigPath, err)
return m.setupUserConfig(sshConfig)
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
return nil
}
// setupUserConfig creates SSH config in user's directory as fallback
func (m *Manager) setupUserConfig(sshConfig string) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("get user home directory: %w", err)
}
userSSHDir := filepath.Join(homeDir, ".ssh")
userConfigPath := filepath.Join(userSSHDir, "config")
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
return fmt.Errorf("create user SSH directory: %w", err)
}
// Check if NetBird config already exists in user config
exists, err := m.configExists(userConfigPath)
if err != nil {
return fmt.Errorf("check existing config: %w", err)
}
if exists {
log.Debugf("NetBird SSH config already exists in %s", userConfigPath)
return nil
}
// Append NetBird config to user's SSH config with timeout
if err := writeFileOperationWithTimeout(userConfigPath, func() error {
file, err := os.OpenFile(userConfigPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return fmt.Errorf("open user SSH config: %w", err)
}
defer func() {
if err := file.Close(); err != nil {
log.Debugf("user SSH config file close error: %v", err)
}
}()
if _, err := fmt.Fprintf(file, "\n%s", sshConfig); err != nil {
return fmt.Errorf("write to user SSH config: %w", err)
}
return nil
}); err != nil {
return err
}
log.Infof("Added NetBird SSH config to user config: %s", userConfigPath)
return nil
}
// configExists checks if NetBird SSH config already exists
func (m *Manager) configExists(configPath string) (bool, error) {
file, err := os.Open(configPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("open SSH config file: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.Contains(line, "NetBird SSH client configuration") {
return true, nil
}
}
return false, scanner.Err()
}
// RemoveSSHClientConfig removes NetBird SSH configuration
func (m *Manager) RemoveSSHClientConfig() error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
// Remove system-wide config if it exists
if err := os.Remove(sshConfigPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove system SSH config %s: %v", sshConfigPath, err)
} else if err == nil {
err := os.Remove(sshConfigPath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
}
if err == nil {
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
}
// Also try to clean up user config
homeDir, err := os.UserHomeDir()
if err != nil {
log.Debugf("failed to get user home directory: %v", err)
return nil
}
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
if err := m.removeFromUserConfig(userConfigPath); err != nil {
log.Warnf("Failed to clean user SSH config: %v", err)
}
return nil
}
// removeFromUserConfig removes NetBird section from user's SSH config
func (m *Manager) removeFromUserConfig(configPath string) error {
// This is complex to implement safely, so for now just log
// In practice, the system-wide config takes precedence anyway
log.Debugf("NetBird SSH config cleanup from user config not implemented")
return nil
}
// setupKnownHostsFile creates and returns the path to NetBird known_hosts file
func (m *Manager) setupKnownHostsFile() (string, error) {
// Try system-wide known_hosts first
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
if err := os.MkdirAll(m.knownHostsDir, 0755); err == nil {
// Create empty file if it doesn't exist
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
if err := writeFileWithTimeout(knownHostsPath, []byte("# NetBird SSH known hosts\n"), 0644); err == nil {
log.Debugf("Created NetBird known_hosts file: %s", knownHostsPath)
return knownHostsPath, nil
}
} else if err == nil {
return knownHostsPath, nil
}
}
// Fallback to user directory
homeDir, err := os.UserHomeDir()
func (m *Manager) getNetBirdExecutablePath() (string, error) {
execPath, err := os.Executable()
if err != nil {
return "", fmt.Errorf("get user home directory: %w", err)
return "", fmt.Errorf("retrieve executable path: %w", err)
}
userSSHDir := filepath.Join(homeDir, ".ssh")
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
return "", fmt.Errorf("create user SSH directory: %w", err)
}
userKnownHostsPath := filepath.Join(userSSHDir, m.userKnownHosts)
if _, err := os.Stat(userKnownHostsPath); os.IsNotExist(err) {
if err := writeFileWithTimeout(userKnownHostsPath, []byte("# NetBird SSH known hosts\n"), 0600); err != nil {
return "", fmt.Errorf("create user known_hosts file: %w", err)
}
log.Debugf("Created NetBird user known_hosts file: %s", userKnownHostsPath)
}
return userKnownHostsPath, nil
}
// UpdatePeerHostKeys updates the known_hosts file with peer host keys
func (m *Manager) UpdatePeerHostKeys(peerKeys []PeerHostKey) error {
peerCount := len(peerKeys)
// Check if SSH config should be generated
if !shouldGenerateSSHConfig(peerCount) {
if isSSHConfigDisabled() {
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
} else {
log.Infof("SSH known_hosts update skipped: too many peers (%d > %d). Use %s=true to force.",
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
}
return nil
}
knownHostsPath, err := m.setupKnownHostsFile()
realPath, err := filepath.EvalSymlinks(execPath)
if err != nil {
return fmt.Errorf("setup known_hosts file: %w", err)
log.Debugf("symlink resolution failed: %v", err)
return execPath, nil
}
// Create updated known_hosts content - NetBird file should only contain NetBird entries
var updatedContent strings.Builder
updatedContent.WriteString("# NetBird SSH known hosts\n")
updatedContent.WriteString("# Generated automatically - do not edit manually\n\n")
// Add new NetBird entries - one entry per peer with all hostnames
for _, peerKey := range peerKeys {
entry := m.formatKnownHostsEntry(peerKey)
updatedContent.WriteString(entry)
updatedContent.WriteString("\n")
}
// Write updated content
if err := writeFileWithTimeout(knownHostsPath, []byte(updatedContent.String()), 0644); err != nil {
return fmt.Errorf("write known_hosts file: %w", err)
}
log.Debugf("Updated NetBird known_hosts with %d peer keys: %s", len(peerKeys), knownHostsPath)
return nil
return realPath, nil
}
// formatKnownHostsEntry formats a peer host key as a known_hosts entry
func (m *Manager) formatKnownHostsEntry(peerKey PeerHostKey) string {
hostnames := m.getHostnameVariants(peerKey)
hostnameList := strings.Join(hostnames, ",")
keyString := string(ssh.MarshalAuthorizedKey(peerKey.HostKey))
keyString = strings.TrimSpace(keyString)
return fmt.Sprintf("%s %s", hostnameList, keyString)
// GetSSHConfigDir returns the SSH config directory path
func (m *Manager) GetSSHConfigDir() string {
return m.sshConfigDir
}
// getHostnameVariants returns all possible hostname variants for a peer
func (m *Manager) getHostnameVariants(peerKey PeerHostKey) []string {
var hostnames []string
// Add IP address
if peerKey.IP != "" {
hostnames = append(hostnames, peerKey.IP)
}
// Add FQDN
if peerKey.FQDN != "" {
hostnames = append(hostnames, peerKey.FQDN)
}
// Add hostname if different from FQDN
if peerKey.Hostname != "" && peerKey.Hostname != peerKey.FQDN {
hostnames = append(hostnames, peerKey.Hostname)
}
// Add bracketed IP for non-standard ports (SSH standard)
if peerKey.IP != "" {
hostnames = append(hostnames, fmt.Sprintf("[%s]:22", peerKey.IP))
hostnames = append(hostnames, fmt.Sprintf("[%s]:22022", peerKey.IP))
}
return hostnames
}
// GetKnownHostsPath returns the path to the NetBird known_hosts file
func (m *Manager) GetKnownHostsPath() (string, error) {
return m.setupKnownHostsFile()
}
// RemoveKnownHostsFile removes the NetBird known_hosts file
func (m *Manager) RemoveKnownHostsFile() error {
// Remove system-wide known_hosts if it exists
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
if err := os.Remove(knownHostsPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove system known_hosts %s: %v", knownHostsPath, err)
} else if err == nil {
log.Infof("Removed NetBird known_hosts: %s", knownHostsPath)
}
// Also try to clean up user known_hosts
homeDir, err := os.UserHomeDir()
if err != nil {
log.Debugf("failed to get user home directory: %v", err)
return nil
}
userKnownHostsPath := filepath.Join(homeDir, ".ssh", m.userKnownHosts)
if err := os.Remove(userKnownHostsPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove user known_hosts %s: %v", userKnownHostsPath, err)
} else if err == nil {
log.Infof("Removed NetBird user known_hosts: %s", userKnownHostsPath)
}
return nil
// GetSSHConfigFile returns the SSH config file name
func (m *Manager) GetSSHConfigFile() string {
return m.sshConfigFile
}

View File

@@ -10,81 +10,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
func TestManager_UpdatePeerHostKeys(t *testing.T) {
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
}
// Generate test host keys
hostKey1, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey1, err := ssh.ParsePrivateKey(hostKey1)
require.NoError(t, err)
hostKey2, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey2, err := ssh.ParsePrivateKey(hostKey2)
require.NoError(t, err)
// Create test peer host keys
peerKeys := []PeerHostKey{
{
Hostname: "peer1",
IP: "100.125.1.1",
FQDN: "peer1.nb.internal",
HostKey: pubKey1.PublicKey(),
},
{
Hostname: "peer2",
IP: "100.125.1.2",
FQDN: "peer2.nb.internal",
HostKey: pubKey2.PublicKey(),
},
}
// Test updating known_hosts
err = manager.UpdatePeerHostKeys(peerKeys)
require.NoError(t, err)
// Verify known_hosts file was created and contains entries
knownHostsPath, err := manager.GetKnownHostsPath()
require.NoError(t, err)
content, err := os.ReadFile(knownHostsPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, "100.125.1.1")
assert.Contains(t, contentStr, "100.125.1.2")
assert.Contains(t, contentStr, "peer1.nb.internal")
assert.Contains(t, contentStr, "peer2.nb.internal")
assert.Contains(t, contentStr, "[100.125.1.1]:22")
assert.Contains(t, contentStr, "[100.125.1.1]:22022")
// Test updating with empty list should preserve structure
err = manager.UpdatePeerHostKeys([]PeerHostKey{})
require.NoError(t, err)
content, err = os.ReadFile(knownHostsPath)
require.NoError(t, err)
assert.Contains(t, string(content), "# NetBird SSH known hosts")
}
func TestManager_SetupSSHClientConfig(t *testing.T) {
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
@@ -93,15 +20,25 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Test SSH config generation with empty peer keys
err = manager.SetupSSHClientConfig(nil)
// Test SSH config generation with peers
peers := []PeerSSHInfo{
{
Hostname: "peer1",
IP: "100.125.1.1",
FQDN: "peer1.nb.internal",
},
{
Hostname: "peer2",
IP: "100.125.1.2",
FQDN: "peer2.nb.internal",
},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Read generated config
@@ -111,134 +48,39 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
configStr := string(content)
// Since we now use per-peer configurations instead of domain patterns,
// we should verify the basic SSH config structure exists
// Verify the basic SSH config structure exists
assert.Contains(t, configStr, "# NetBird SSH client configuration")
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
// Should not contain /dev/null since we have a proper known_hosts setup
assert.NotContains(t, configStr, "UserKnownHostsFile /dev/null")
}
// Check that peer hostnames are included
assert.Contains(t, configStr, "100.125.1.1")
assert.Contains(t, configStr, "100.125.1.2")
assert.Contains(t, configStr, "peer1.nb.internal")
assert.Contains(t, configStr, "peer2.nb.internal")
func TestManager_GetHostnameVariants(t *testing.T) {
manager := NewManager()
peerKey := PeerHostKey{
Hostname: "testpeer",
IP: "100.125.1.10",
FQDN: "testpeer.nb.internal",
HostKey: nil, // Not needed for this test
}
variants := manager.getHostnameVariants(peerKey)
expectedVariants := []string{
"100.125.1.10",
"testpeer.nb.internal",
"testpeer",
"[100.125.1.10]:22",
"[100.125.1.10]:22022",
}
assert.ElementsMatch(t, expectedVariants, variants)
}
func TestManager_FormatKnownHostsEntry(t *testing.T) {
manager := NewManager()
// Generate test key
hostKeyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
parsedKey, err := ssh.ParsePrivateKey(hostKeyPEM)
require.NoError(t, err)
peerKey := PeerHostKey{
Hostname: "testpeer",
IP: "100.125.1.10",
FQDN: "testpeer.nb.internal",
HostKey: parsedKey.PublicKey(),
}
entry := manager.formatKnownHostsEntry(peerKey)
// Should contain all hostname variants
assert.Contains(t, entry, "100.125.1.10")
assert.Contains(t, entry, "testpeer.nb.internal")
assert.Contains(t, entry, "testpeer")
assert.Contains(t, entry, "[100.125.1.10]:22")
assert.Contains(t, entry, "[100.125.1.10]:22022")
// Should contain the public key
keyString := string(ssh.MarshalAuthorizedKey(parsedKey.PublicKey()))
keyString = strings.TrimSpace(keyString)
assert.Contains(t, entry, keyString)
// Should be properly formatted (hostnames followed by key)
parts := strings.Fields(entry)
assert.GreaterOrEqual(t, len(parts), 2, "Entry should have hostnames and key parts")
}
func TestManager_DirectoryFallback(t *testing.T) {
// Create temporary directory for test where system dirs will fail
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Set HOME to temp directory to control user fallback
t.Setenv("HOME", tempDir)
// Create manager with non-writable system directories
// Use paths that will fail on all systems
var failPath string
// Check platform-specific UserKnownHostsFile
if runtime.GOOS == "windows" {
failPath = "NUL:" // Special device that can't be used as directory on Windows
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
} else {
failPath = "/dev/null" // Special device that can't be used as directory on Unix
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
}
manager := &Manager{
sshConfigDir: failPath + "/ssh_config.d", // Should fail
sshConfigFile: "99-netbird.conf",
knownHostsDir: failPath + "/ssh_known_hosts.d", // Should fail
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
}
// Should fall back to user directory
knownHostsPath, err := manager.setupKnownHostsFile()
require.NoError(t, err)
// Get the actual user home directory as determined by os.UserHomeDir()
userHome, err := os.UserHomeDir()
require.NoError(t, err)
expectedUserPath := filepath.Join(userHome, ".ssh", "known_hosts_netbird")
assert.Equal(t, expectedUserPath, knownHostsPath)
// Verify file was created
_, err = os.Stat(knownHostsPath)
require.NoError(t, err)
}
func TestGetSystemSSHPaths(t *testing.T) {
configDir, knownHostsDir := getSystemSSHPaths()
func TestGetSystemSSHConfigDir(t *testing.T) {
configDir := getSystemSSHConfigDir()
// Paths should not be empty
// Path should not be empty
assert.NotEmpty(t, configDir)
assert.NotEmpty(t, knownHostsDir)
// Should be absolute paths
// Should be an absolute path
assert.True(t, filepath.IsAbs(configDir))
assert.True(t, filepath.IsAbs(knownHostsDir))
// On Unix systems, should start with /etc
// On Windows, should contain ProgramData
if runtime.GOOS == "windows" {
assert.Contains(t, strings.ToLower(configDir), "programdata")
assert.Contains(t, strings.ToLower(knownHostsDir), "programdata")
} else {
assert.Contains(t, configDir, "/etc/ssh")
assert.Contains(t, knownHostsDir, "/etc/ssh")
}
}
@@ -250,46 +92,28 @@ func TestManager_PeerLimit(t *testing.T) {
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Generate many peer keys (more than limit)
var peerKeys []PeerHostKey
// Generate many peers (more than limit)
var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey, err := ssh.ParsePrivateKey(hostKey)
require.NoError(t, err)
peerKeys = append(peerKeys, PeerHostKey{
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
HostKey: pubKey.PublicKey(),
})
}
// Test that SSH config generation is skipped when too many peers
err = manager.SetupSSHClientConfig(peerKeys)
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Config should not be created due to peer limit
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
_, err = os.Stat(configPath)
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
// Test that known_hosts update is also skipped
err = manager.UpdatePeerHostKeys(peerKeys)
require.NoError(t, err)
// Known hosts should not be created due to peer limit
knownHostsPath := filepath.Join(manager.knownHostsDir, manager.knownHostsFile)
_, err = os.Stat(knownHostsPath)
assert.True(t, os.IsNotExist(err), "Known hosts should not be created with too many peers")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
@@ -303,31 +127,22 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Generate many peer keys (more than limit)
var peerKeys []PeerHostKey
// Generate many peers (more than limit)
var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey, err := ssh.ParsePrivateKey(hostKey)
require.NoError(t, err)
peerKeys = append(peerKeys, PeerHostKey{
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
HostKey: pubKey.PublicKey(),
})
}
// Test that SSH config generation is forced despite many peers
err = manager.SetupSSHClientConfig(peerKeys)
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Config should be created despite peer limit due to force flag

View File

@@ -0,0 +1,22 @@
package config
// ShutdownState represents SSH configuration state that needs to be cleaned up.
type ShutdownState struct {
SSHConfigDir string
SSHConfigFile string
}
// Name returns the state name for the state manager.
func (s *ShutdownState) Name() string {
return "ssh_config_state"
}
// Cleanup removes SSH client configuration files.
func (s *ShutdownState) Cleanup() error {
manager := &Manager{
sshConfigDir: s.SSHConfigDir,
sshConfigFile: s.SSHConfigFile,
}
return manager.RemoveSSHClientConfig()
}

View File

@@ -0,0 +1,99 @@
package detection
import (
"bufio"
"context"
"net"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
// ServerIdentifier is the base response for NetBird SSH servers
ServerIdentifier = "NetBird-SSH-Server"
// ProxyIdentifier is the base response for NetBird SSH proxy
ProxyIdentifier = "NetBird-SSH-Proxy"
// JWTRequiredMarker is appended to responses when JWT is required
JWTRequiredMarker = "NetBird-JWT-Required"
// Timeout is the timeout for SSH server detection
Timeout = 5 * time.Second
)
type ServerType string
const (
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
ServerTypeRegular ServerType = "regular"
)
// Dialer provides network connection capabilities
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// RequiresJWT checks if the server type requires JWT authentication
func (s ServerType) RequiresJWT() bool {
return s == ServerTypeNetBirdJWT
}
// ExitCode returns the exit code for the detect command
func (s ServerType) ExitCode() int {
switch s {
case ServerTypeNetBirdJWT:
return 0
case ServerTypeNetBirdNoJWT:
return 1
case ServerTypeRegular:
return 2
default:
return 2
}
}
// DetectSSHServerType detects SSH server type using the provided dialer
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
if err != nil {
log.Debugf("SSH connection failed for detection: %v", err)
return ServerTypeRegular, nil
}
defer conn.Close()
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
log.Debugf("set read deadline: %v", err)
return ServerTypeRegular, nil
}
reader := bufio.NewReader(conn)
serverBanner, err := reader.ReadString('\n')
if err != nil {
log.Debugf("read SSH banner: %v", err)
return ServerTypeRegular, nil
}
serverBanner = strings.TrimSpace(serverBanner)
log.Debugf("SSH server banner: %s", serverBanner)
if !strings.HasPrefix(serverBanner, "SSH-") {
log.Debugf("Invalid SSH banner")
return ServerTypeRegular, nil
}
if !strings.Contains(serverBanner, ServerIdentifier) {
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
return ServerTypeRegular, nil
}
if strings.Contains(serverBanner, JWTRequiredMarker) {
return ServerTypeNetBirdJWT, nil
}
return ServerTypeNetBirdNoJWT, nil
}

359
client/ssh/proxy/proxy.go Normal file
View File

@@ -0,0 +1,359 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/version"
)
const (
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
sshConnectionTimeout = 120 * time.Second
// sshHandshakeTimeout is the timeout for SSH handshake completion
sshHandshakeTimeout = 30 * time.Second
jwtAuthErrorMsg = "JWT authentication: %w"
)
type SSHProxy struct {
daemonAddr string
targetHost string
targetPort int
stderr io.Writer
daemonClient proto.DaemonServiceClient
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, fmt.Errorf("connect to daemon: %w", err)
}
return &SSHProxy{
daemonAddr: daemonAddr,
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}, nil
}
func (p *SSHProxy) Connect(ctx context.Context) error {
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": func(s ssh.Session) {
p.sftpSubsystemHandler(s, jwtToken)
},
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": p.tcpipForwardHandler,
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
},
Version: serverVersion,
}
hostKey, err := generateHostKey()
if err != nil {
return fmt.Errorf("generate host key: %w", err)
}
sshServer.HostSigners = []ssh.Signer{hostKey}
conn := &stdioConn{
stdin: os.Stdin,
stdout: os.Stdout,
}
sshServer.HandleConn(conn)
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
_ = serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil)
go func() {
for win := range winCh {
_ = serverSession.WindowChange(win.Height, win.Width)
}
}()
}
if len(session.Command()) > 0 {
_ = serverSession.Run(strings.Join(session.Command(), " "))
return
}
if err = serverSession.Shell(); err == nil {
_ = serverSession.Wait()
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
return nil, fmt.Errorf("generate ED25519 key: %w", err)
}
signer, err := cryptossh.ParsePrivateKey(keyPEM)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
return signer, nil
}
type stdioConn struct {
stdin io.Reader
stdout io.Writer
closed bool
mu sync.Mutex
}
func (c *stdioConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.EOF
}
c.mu.Unlock()
return c.stdin.Read(b)
}
func (c *stdioConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.ErrClosedPipe
}
c.mu.Unlock()
return c.stdout.Write(b)
}
func (c *stdioConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return nil
}
func (c *stdioConn) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
ctx, cancel := context.WithCancel(s.Context())
defer cancel()
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := sshClient.Close(); err != nil {
log.Debugf("close SSH client: %v", err)
}
}()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := serverSession.Close(); err != nil {
log.Debugf("close server session: %v", err)
}
}()
stdin, stdout, err := p.setupSFTPPipes(serverSession)
if err != nil {
log.Debugf("setup SFTP pipes: %v", err)
_ = s.Exit(1)
return
}
if err := serverSession.RequestSubsystem("sftp"); err != nil {
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
_ = s.Exit(1)
return
}
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
}
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
stdin, err := serverSession.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
}
stdout, err := serverSession.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
}
return stdin, stdout, nil
}
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
copyErrCh := make(chan error, 2)
go func() {
_, err := io.Copy(stdin, s)
if err != nil {
log.Debugf("SFTP client to server copy: %v", err)
}
if err := stdin.Close(); err != nil {
log.Debugf("close stdin: %v", err)
}
copyErrCh <- err
}()
go func() {
_, err := io.Copy(s, stdout)
if err != nil {
log.Debugf("SFTP server to client copy: %v", err)
}
copyErrCh <- err
}()
go func() {
<-ctx.Done()
if err := serverSession.Close(); err != nil {
log.Debugf("force close server session on context cancellation: %v", err)
}
}()
for i := 0; i < 2; i++ {
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
log.Debugf("SFTP copy error: %v", err)
}
}
if err := serverSession.Wait(); err != nil {
log.Debugf("SFTP session ended: %v", err)
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
}
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return true, nil
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
config := &cryptossh.ClientConfig{
User: user,
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
Timeout: sshHandshakeTimeout,
HostKeyCallback: p.verifyHostKey,
}
dialer := &net.Dialer{
Timeout: sshConnectionTimeout,
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("connect to server: %w", err)
}
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("SSH handshake: %w", err)
}
return cryptossh.NewClient(clientConn, chans, reqs), nil
}
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}

View File

@@ -0,0 +1,361 @@
package proxy
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
)
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "ssh" {
if os.Args[2] == "exec" {
if len(os.Args) > 3 {
cmd := os.Args[3]
if cmd == "echo" && len(os.Args) > 4 {
fmt.Fprintln(os.Stdout, os.Args[4])
os.Exit(0)
}
}
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
os.Exit(1)
}
}
code := m.Run()
testutil.CleanupTestUsers()
os.Exit(code)
}
func TestSSHProxy_verifyHostKey(t *testing.T) {
t.Run("calls daemon to verify host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testPubKey, err := nbssh.GeneratePublicKey(testKey)
require.NoError(t, err)
mockDaemon.setHostKey("test-host", testPubKey)
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
assert.NoError(t, err)
})
t.Run("rejects unknown host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
require.NoError(t, err)
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
assert.Error(t, err)
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
jwtToken string
}
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
key, found := m.hostKeys[req.PeerAddress]
return &proto.GetPeerSSHHostKeyResponse{
Found: found,
SshHostKey: key,
}, nil
}
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
return &proto.RequestJWTAuthResponse{
CachedToken: m.jwtToken,
}, nil
}
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
return &proto.WaitJWTTokenResponse{
Token: m.jwtToken,
}, nil
}
type mockDaemon struct {
addr string
server *grpc.Server
impl *mockDaemonServer
}
func startMockDaemon(t *testing.T) *mockDaemon {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
impl := &mockDaemonServer{
hostKeys: make(map[string][]byte),
jwtToken: "test-jwt-token",
}
grpcServer := grpc.NewServer()
proto.RegisterDaemonServiceServer(grpcServer, impl)
go func() {
_ = grpcServer.Serve(listener)
}()
return &mockDaemon{
addr: listener.Addr().String(),
server: grpcServer,
impl: impl,
}
}
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
}
}
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
t.Helper()
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -9,7 +9,6 @@ import (
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
@@ -21,15 +20,24 @@ import (
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
cleanupTestUsers()
testutil.CleanupTestUsers()
os.Exit(code)
}
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err)
// Generate OpenSSH-compatible keys for client
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t)
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err)
// Get appropriate user for SSH connection (handle system accounts)
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
t.Run("basic command execution", func(t *testing.T) {
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
// testSSHInteractiveCommand tests interactive shell session.
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// testSSHPortForwarding tests port forwarding compatibility.
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
t.Skip("Skipping SSH feature compatibility tests in short mode")
}
if runtime.GOOS == "windows" && isCI() {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
}
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
// testCommandWithFlags tests that commands with flags work properly
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Test ls with flags
cmd := exec.Command("ssh",
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// testEnvironmentVariables tests that environment is properly set up
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
cmd := exec.Command("ssh",
"-i", keyFile,
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// testExitCodes tests that exit codes are properly handled
func testExitCodes(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Test successful command (exit code 0)
cmd := exec.Command("ssh",
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
}
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Set up SSH server with specific security settings
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
}
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
t.Logf("Platform command output: %s", outputStr)
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
}
// getTestUsername returns an appropriate username for testing
func getTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Check if this is a system account that can't authenticate
if isSystemAccount(currentUser.Username) {
// In CI environments, create a test user; otherwise try Administrator
if isCI() {
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
// Try Administrator first for local development
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// isCI checks if we're running in a CI environment
func isCI() bool {
// Check standard CI environment variables
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// isSystemAccount checks if the user is a system account that can't authenticate
func isSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
var compatTestCreatedUsers = make(map[string]bool)
var compatTestUsersToCleanup []string
// registerTestUserCleanup registers a test user for cleanup
func registerTestUserCleanup(username string) {
if !compatTestCreatedUsers[username] {
compatTestCreatedUsers[username] = true
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
}
}
// cleanupTestUsers removes all created test users
func cleanupTestUsers() {
for _, username := range compatTestUsersToCleanup {
removeWindowsTestUser(username)
}
compatTestUsersToCleanup = nil
compatTestCreatedUsers = make(map[string]bool)
}
// getOrCreateTestUser creates a test user on Windows if needed
func getOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
// Check if user already exists
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
// Try to create the user using PowerShell
if createWindowsTestUser(t, testUsername) {
// Register cleanup for the test user
registerTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// removeWindowsTestUser removes a local user on Windows using PowerShell
func removeWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
// PowerShell command to remove a local user
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// createWindowsTestUser creates a local user on Windows using PowerShell
func createWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
// PowerShell command to create a local user
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}

View File

@@ -0,0 +1,610 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
cryptossh "golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
)
func TestJWTEnforcement(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT enforcement tests in short mode")
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
if err != nil {
t.Logf("Detection failed: %v", err)
}
t.Logf("Detected server type: %s", serverType)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
})
t.Run("allows_when_disabled", func(t *testing.T) {
serverConfigNoJWT := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
serverNoJWT := New(serverConfigNoJWT)
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
portNoJWT, err := strconv.Atoi(portStrNoJWT)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
assert.False(t, serverType.RequiresJWT())
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
require.NoError(t, err)
defer client.Close()
})
}
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64RawURLEncode(n),
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func base64RawURLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
// generateValidJWT creates a valid JWT token for testing
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
t.Helper()
addr := net.JoinHostPort(host, strconv.Itoa(port))
ctx := context.Background()
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
InsecureSkipVerify: true,
})
}
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
func TestJWTDetection(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT detection test in short mode")
}
jwksServer, _, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
assert.True(t, serverType.RequiresJWT())
}
func TestJWTFailClose(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT fail-close tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
tokenClaims jwt.MapClaims
}{
{
name: "blocks_token_missing_iat",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
},
},
{
name: "blocks_token_missing_sub",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_iss",
tokenClaims: jwt.MapClaims{
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_aud",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_issuer",
tokenClaims: jwt.MapClaims{
"iss": "wrong-issuer",
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_audience",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": "wrong-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_expired_token",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(tokenString),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
assert.Error(t, err, "Authentication should fail (fail-close)")
})
}
}
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
func TestJWTAuthentication(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT authentication tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
token string
wantAuthOK bool
setupServer func(*Server)
testOperation func(*testing.T, *cryptossh.Client, string) error
wantOpSuccess bool
}{
{
name: "allows_shell_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
return session.Shell()
},
wantOpSuccess: true,
},
{
name: "rejects_invalid_token",
token: "invalid",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_shell_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_command_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("ls")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "allows_sftp_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
return session.RequestSubsystem("sftp")
},
wantOpSuccess: true,
},
{
name: "blocks_sftp_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
err = session.RequestSubsystem("sftp")
if err == nil {
err = session.Wait()
}
return err
},
wantOpSuccess: false,
},
{
name: "allows_port_forward_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowRemotePortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: true,
},
{
name: "blocks_port_forward_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowLocalPortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
if tc.setupServer != nil {
tc.setupServer(server)
}
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
if tc.token == "valid" {
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
} else if tc.token == "invalid" {
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),
}
}
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: authMethods,
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed")
} else if err != nil {
t.Logf("Connection failed as expected: %v", err)
return
}
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
err = tc.testOperation(t, conn, serverAddr)
if tc.wantOpSuccess {
require.NoError(t, err, "Operation should succeed")
} else {
assert.Error(t, err, "Operation should fail")
}
})
}
}

View File

@@ -2,18 +2,27 @@ package server
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/version"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
@@ -27,6 +36,9 @@ const (
errExitSession = "exit session error: %v"
msgPrivilegedUserDisabled = "privileged user login is disabled"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
var (
@@ -69,7 +81,6 @@ func (e *UserNotFoundError) Unwrap() error {
}
// safeLogCommand returns a safe representation of the command for logging
// Only logs the first argument to avoid leaking sensitive information
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
return "<empty>"
@@ -80,17 +91,14 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
// sshConnectionState tracks the state of an SSH connection
type sshConnectionState struct {
hasActivePortForward bool
username string
remoteAddr string
}
// Server is the SSH server implementation
type Server struct {
sshServer *ssh.Server
authorizedKeys map[string]ssh.PublicKey
mu sync.RWMutex
hostKeyPEM []byte
sessions map[SessionKey]ssh.Session
@@ -100,30 +108,53 @@ type Server struct {
allowRemotePortForwarding bool
allowRootLogin bool
allowSFTP bool
jwtEnabled bool
netstackNet *netstack.Net
wgAddress wgaddr.Address
ifIdx int
remoteForwardListeners map[ForwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
}
// New creates an SSH server instance with the provided host key
func New(hostKeyPEM []byte) *Server {
return &Server{
type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
}
// Config contains all SSH server configuration options
type Config struct {
// JWT authentication configuration. If nil, JWT authentication is disabled
JWT *JWTConfig
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: hostKeyPEM,
authorizedKeys: make(map[string]ssh.PublicKey),
hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session),
remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
}
return s
}
// Start runs the SSH server, automatically detecting netstack vs standard networking
// Does all setup synchronously, then starts serving in a goroutine and returns immediately
// Start runs the SSH server
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -139,7 +170,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
sshServer, err := s.createSSHServer(ln.Addr())
if err != nil {
s.cleanupOnError(ln)
s.closeListener(ln)
return fmt.Errorf("create SSH server: %w", err)
}
@@ -154,7 +185,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return nil
}
// createListener creates a network listener based on netstack vs standard networking
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
@@ -173,22 +203,15 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
return ln, addr.String(), nil
}
// closeListener safely closes a listener
func (s *Server) closeListener(ln net.Listener) {
if ln == nil {
return
}
if err := ln.Close(); err != nil {
log.Debugf("listener close error: %v", err)
}
}
// cleanupOnError cleans up resources when SSH server creation fails
func (s *Server) cleanupOnError(ln net.Listener) {
if s.ifIdx == 0 || ln == nil {
return
}
s.closeListener(ln)
}
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
@@ -207,28 +230,6 @@ func (s *Server) Stop() error {
return nil
}
// RemoveAuthorizedKey removes the SSH key for a peer
func (s *Server) RemoveAuthorizedKey(peer string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.authorizedKeys, peer)
}
// AddAuthorizedKey adds an SSH key for a peer
func (s *Server) AddAuthorizedKey(peer, newKey string) error {
s.mu.Lock()
defer s.mu.Unlock()
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
if err != nil {
return fmt.Errorf("parse key: %w", err)
}
s.authorizedKeys[peer] = parsedKey
return nil
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
@@ -243,34 +244,195 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
// SetSocketFilter configures eBPF socket filtering for the SSH server
func (s *Server) SetSocketFilter(ifIdx int) {
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
s.mu.RUnlock()
return nil
}
config := s.jwtConfig
s.mu.RUnlock()
if config == nil {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.KeysLocation,
true,
)
extractor := jwt.NewClaimsExtractor(
jwt.WithAudience(config.Audience),
)
s.mu.Lock()
defer s.mu.Unlock()
s.ifIdx = ifIdx
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
s.jwtValidator = validator
s.jwtExtractor = extractor
log.Infof("JWT validator initialized successfully")
return nil
}
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
s.mu.RLock()
defer s.mu.RUnlock()
jwtValidator := s.jwtValidator
jwtConfig := s.jwtConfig
s.mu.RUnlock()
for _, allowed := range s.authorizedKeys {
if ssh.KeysEqual(allowed, key) {
if ctx != nil {
log.Debugf("SSH key authentication successful for user %s from %s", ctx.User(), ctx.RemoteAddr())
if jwtValidator == nil {
return nil, fmt.Errorf("JWT validator not initialized")
}
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
}
return true
}
return nil, fmt.Errorf("validate token: %w", err)
}
if ctx != nil {
log.Warnf("SSH key authentication failed for user %s from %s: key not authorized (type: %s, fingerprint: %s)",
ctx.User(), ctx.RemoteAddr(), key.Type(), cryptossh.FingerprintSHA256(key))
if err := s.checkTokenAge(token, jwtConfig); err != nil {
return nil, err
}
return false
return token, nil
}
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
return nil
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
s.mu.RLock()
jwtExtractor := s.jwtExtractor
s.mu.RUnlock()
if jwtExtractor == nil {
userID := extractUserID(token)
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
}
userAuth, err := jwtExtractor.ToUserAuth(token)
if err != nil {
userID := extractUserID(token)
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
}
if !s.hasSSHAccess(&userAuth) {
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
}
return &userAuth, nil
}
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
return userAuth.UserId != ""
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid token format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("decode payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("parse claims: %w", err)
}
return claims, nil
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
}
// markConnectionActivePortForward marks an SSH connection as having an active port forward
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -286,14 +448,12 @@ func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn,
}
}
// connectionCloseHandler cleans up connection state when SSH connections fail/close
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
// We can't extract the SSH connection from net.Conn directly
// Connection cleanup will happen during session cleanup or via timeout
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
}
// findSessionKeyByContext finds the session key by matching SSH connection context
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
if ctx == nil {
return "unknown"
@@ -319,14 +479,13 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
// Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("using temporary session key for port forward tracking: %s", tempKey)
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey
}
return "unknown"
}
// connectionValidator validates incoming connections based on source IP
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()
netbirdNetwork := s.wgAddress.Network
@@ -340,8 +499,8 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
remoteAddr := conn.RemoteAddr()
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok {
log.Debugf("SSH connection from non-TCP address %s allowed", remoteAddr)
return conn
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
return nil
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
@@ -357,15 +516,14 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
}
if !netbirdNetwork.Contains(remoteIP) {
log.Warnf("SSH connection rejected from non-NetBird IP %s (allowed range: %s)", remoteIP, netbirdNetwork)
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil
}
log.Debugf("SSH connection from %s allowed", remoteIP)
log.Infof("SSH connection from NetBird peer %s allowed", remoteIP)
return conn
}
// isShutdownError checks if the error is expected during normal shutdown
func isShutdownError(err error) bool {
if errors.Is(err, net.ErrClosed) {
return true
@@ -379,12 +537,16 @@ func isShutdownError(err error) bool {
return false
}
// createSSHServer creates and configures the SSH server
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err)
}
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
if s.jwtEnabled {
serverVersion += " " + detection.JWTRequiredMarker
}
server := &ssh.Server{
Addr: addr.String(),
Handler: s.sessionHandler,
@@ -402,6 +564,11 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
},
ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler,
Version: serverVersion,
}
if s.jwtEnabled {
server.PasswordHandler = s.passwordHandler
}
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
@@ -413,14 +580,12 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil
}
// storeRemoteForwardListener stores a remote forward listener for cleanup
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
// removeRemoteForwardListener removes and closes a remote forward listener
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
@@ -438,7 +603,6 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
return true
}
// directTCPIPHandler handles direct-tcpip channel requests for local port forwarding with privilege validation
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
var payload struct {
Host string

View File

@@ -22,12 +22,6 @@ func TestServer_RootLoginRestriction(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
tests := []struct {
name string
allowRoot bool
@@ -117,10 +111,12 @@ func TestServer_RootLoginRestriction(t *testing.T) {
defer cleanup()
// Create server with specific configuration
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(tt.allowRoot)
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Test the userNameLookup method directly
user, err := server.userNameLookup(tt.username)
@@ -196,7 +192,11 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create server with specific configuration
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
@@ -234,17 +234,13 @@ func TestServer_PortConflictHandling(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -263,7 +259,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1()
client1, err := sshclient.DialInsecure(ctx1, serverAddr, currentUser.Username)
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client1.Close()
@@ -274,7 +272,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
client2, err := sshclient.DialInsecure(ctx2, serverAddr, currentUser.Username)
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client2.Close()

View File

@@ -7,11 +7,9 @@ import (
"net/netip"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
@@ -19,82 +17,15 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
)
func TestServer_AddAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
keys := map[string][]byte{}
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys[peer] = remotePubKey
}
for peer, remotePubKey := range keys {
k, ok := server.authorizedKeys[peer]
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
}
}
func TestServer_RemoveAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
require.NoError(t, err)
server.RemoveAuthorizedKey("remotePeer")
_, ok := server.authorizedKeys["remotePeer"]
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
}
func TestServer_PubKeyHandler(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
var keys []ssh.PublicKey
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys = append(keys, remoteParsedPubKey)
}
for _, key := range keys {
accepted := server.publicKeyHandler(nil, key)
assert.True(t, accepted, "SSH key should be accepted")
}
}
func TestServer_StartStop(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
serverConfig := &Config{
HostKeyPEM: key,
JWT: nil,
}
server := New(serverConfig)
err = server.Stop()
assert.NoError(t, err)
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with random port
server := New(hostKey)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server in background
serverAddr := "127.0.0.1:0"
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server
server := New(hostKey)
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate authorized key
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey)
require.NoError(t, err)
// Generate unauthorized key (different from authorized)
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server with only one authorized key
server := New(hostKey)
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
require.NoError(t, err)
}()
// Parse unauthorized private key
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey)
// Generate a client private key for SSH protocol (server doesn't check it)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
// Try to connect with unauthorized key
// Try to connect with client key
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(unauthorizedSigner),
cryptossh.PublicKeys(clientSigner),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// This should succeed in no-auth mode
// This should succeed in no-auth mode (server doesn't verify keys)
conn, err := cryptossh.Dial("tcp", serverAddr, config)
assert.NoError(t, err, "Connection should succeed in no-auth mode")
if conn != nil {
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
serverAddr := "127.0.0.1:0"
// Test multiple start/stop cycles
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server1 := New(hostKey)
server2 := New(hostKey)
serverConfig1 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server1 := New(serverConfig1)
serverConfig2 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server2 := New(serverConfig2)
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")

View File

@@ -35,17 +35,15 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP enabled
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(true)
server.SetAllowRootLogin(true) // Allow root login for testing
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
server.SetAllowRootLogin(true)
// Start server
serverAddr := "127.0.0.1:0"
@@ -144,17 +142,15 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP disabled
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(false)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)

View File

@@ -14,7 +14,6 @@ func StartTestServer(t *testing.T, server *Server) string {
errChan := make(chan error, 1)
go func() {
// Get a free port
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
errChan <- err
@@ -26,9 +25,12 @@ func StartTestServer(t *testing.T, server *Server) string {
return
}
started <- actualAddr
addrPort := netip.MustParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {

View File

@@ -0,0 +1,172 @@
package testutil
import (
"fmt"
"log"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
var testCreatedUsers = make(map[string]bool)
var testUsersToCleanup []string
// GetTestUsername returns an appropriate username for testing
func GetTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
if IsSystemAccount(currentUser.Username) {
if IsCI() {
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// IsCI checks if we're running in a CI environment
func IsCI() bool {
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// IsSystemAccount checks if the user is a system account that can't authenticate
func IsSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
// RegisterTestUserCleanup registers a test user for cleanup
func RegisterTestUserCleanup(username string) {
if !testCreatedUsers[username] {
testCreatedUsers[username] = true
testUsersToCleanup = append(testUsersToCleanup, username)
}
}
// CleanupTestUsers removes all created test users
func CleanupTestUsers() {
for _, username := range testUsersToCleanup {
RemoveWindowsTestUser(username)
}
testUsersToCleanup = nil
testCreatedUsers = make(map[string]bool)
}
// GetOrCreateTestUser creates a test user on Windows if needed
func GetOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
if CreateWindowsTestUser(t, testUsername) {
RegisterTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// RemoveWindowsTestUser removes a local user on Windows using PowerShell
func RemoveWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// CreateWindowsTestUser creates a local user on Windows using PowerShell
func CreateWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}