Files
netbird/client/ssh/server/sftp_test.go
2025-11-17 17:10:41 +01:00

229 lines
5.7 KiB
Go

package server
import (
"context"
"fmt"
"net"
"net/netip"
"os"
"os/user"
"testing"
"time"
"github.com/pkg/sftp"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
)
func TestSSHServer_SFTPSubsystem(t *testing.T) {
// Skip SFTP test when running as root due to protocol issues in some environments
if os.Geteuid() == 0 {
t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
}
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server with SFTP enabled
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(true)
server.SetAllowRootLogin(true)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// (currentUser already obtained at function start)
// Create SSH client connection
clientConfig := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 5 * time.Second,
}
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
require.NoError(t, err, "SSH connection should succeed")
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
// Create SFTP client
sftpClient, err := sftp.NewClient(conn)
require.NoError(t, err, "SFTP client creation should succeed")
defer func() {
if err := sftpClient.Close(); err != nil {
t.Logf("SFTP client close error: %v", err)
}
}()
// Test basic SFTP operations
workingDir, err := sftpClient.Getwd()
assert.NoError(t, err, "Should be able to get working directory")
assert.NotEmpty(t, workingDir, "Working directory should not be empty")
// Test directory listing
files, err := sftpClient.ReadDir(".")
assert.NoError(t, err, "Should be able to list current directory")
assert.NotNil(t, files, "File list should not be nil")
}
func TestSSHServer_SFTPDisabled(t *testing.T) {
// Get current user for SSH connection
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Generate host key for server
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Create server with SFTP disabled
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(false)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)
errChan := make(chan error, 1)
go func() {
ln, err := net.Listen("tcp", serverAddr)
if err != nil {
errChan <- err
return
}
actualAddr := ln.Addr().String()
if err := ln.Close(); err != nil {
errChan <- fmt.Errorf("close temp listener: %w", err)
return
}
addrPort, _ := netip.ParseAddrPort(actualAddr)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {
case actualAddr := <-started:
serverAddr = actualAddr
case err := <-errChan:
t.Fatalf("Server failed to start: %v", err)
case <-time.After(5 * time.Second):
t.Fatal("Server start timeout")
}
defer func() {
err := server.Stop()
require.NoError(t, err)
}()
// Parse client private key
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
require.NoError(t, err)
hostPubKey := hostPrivParsed.PublicKey()
// (currentUser already obtained at function start)
// Create SSH client connection
clientConfig := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(signer),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 5 * time.Second,
}
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
require.NoError(t, err, "SSH connection should succeed")
defer func() {
if err := conn.Close(); err != nil {
t.Logf("connection close error: %v", err)
}
}()
// Try to create SFTP client - should fail when SFTP is disabled
_, err = sftp.NewClient(conn)
assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
}