mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
229 lines
5.7 KiB
Go
229 lines
5.7 KiB
Go
package server
|
|
|
|
import (
|
|
"context"
|
|
"fmt"
|
|
"net"
|
|
"net/netip"
|
|
"os"
|
|
"os/user"
|
|
"testing"
|
|
"time"
|
|
|
|
"github.com/pkg/sftp"
|
|
"github.com/stretchr/testify/assert"
|
|
"github.com/stretchr/testify/require"
|
|
cryptossh "golang.org/x/crypto/ssh"
|
|
|
|
"github.com/netbirdio/netbird/client/ssh"
|
|
)
|
|
|
|
func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
|
// Skip SFTP test when running as root due to protocol issues in some environments
|
|
if os.Geteuid() == 0 {
|
|
t.Skip("Skipping SFTP test when running as root - may have protocol compatibility issues")
|
|
}
|
|
|
|
// Get current user for SSH connection
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err, "Should be able to get current user")
|
|
|
|
// Generate host key for server
|
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Generate client key pair
|
|
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Create server with SFTP enabled
|
|
serverConfig := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server := New(serverConfig)
|
|
server.SetAllowSFTP(true)
|
|
server.SetAllowRootLogin(true)
|
|
|
|
// Start server
|
|
serverAddr := "127.0.0.1:0"
|
|
started := make(chan string, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
ln, err := net.Listen("tcp", serverAddr)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
actualAddr := ln.Addr().String()
|
|
if err := ln.Close(); err != nil {
|
|
errChan <- fmt.Errorf("close temp listener: %w", err)
|
|
return
|
|
}
|
|
|
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
started <- actualAddr
|
|
}()
|
|
|
|
select {
|
|
case actualAddr := <-started:
|
|
serverAddr = actualAddr
|
|
case err := <-errChan:
|
|
t.Fatalf("Server failed to start: %v", err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Server start timeout")
|
|
}
|
|
|
|
defer func() {
|
|
err := server.Stop()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
// Parse client private key
|
|
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
// Parse server host key
|
|
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
|
require.NoError(t, err)
|
|
hostPubKey := hostPrivParsed.PublicKey()
|
|
|
|
// (currentUser already obtained at function start)
|
|
|
|
// Create SSH client connection
|
|
clientConfig := &cryptossh.ClientConfig{
|
|
User: currentUser.Username,
|
|
Auth: []cryptossh.AuthMethod{
|
|
cryptossh.PublicKeys(signer),
|
|
},
|
|
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
|
require.NoError(t, err, "SSH connection should succeed")
|
|
defer func() {
|
|
if err := conn.Close(); err != nil {
|
|
t.Logf("connection close error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Create SFTP client
|
|
sftpClient, err := sftp.NewClient(conn)
|
|
require.NoError(t, err, "SFTP client creation should succeed")
|
|
defer func() {
|
|
if err := sftpClient.Close(); err != nil {
|
|
t.Logf("SFTP client close error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Test basic SFTP operations
|
|
workingDir, err := sftpClient.Getwd()
|
|
assert.NoError(t, err, "Should be able to get working directory")
|
|
assert.NotEmpty(t, workingDir, "Working directory should not be empty")
|
|
|
|
// Test directory listing
|
|
files, err := sftpClient.ReadDir(".")
|
|
assert.NoError(t, err, "Should be able to list current directory")
|
|
assert.NotNil(t, files, "File list should not be nil")
|
|
}
|
|
|
|
func TestSSHServer_SFTPDisabled(t *testing.T) {
|
|
// Get current user for SSH connection
|
|
currentUser, err := user.Current()
|
|
require.NoError(t, err, "Should be able to get current user")
|
|
|
|
// Generate host key for server
|
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Generate client key pair
|
|
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
require.NoError(t, err)
|
|
|
|
// Create server with SFTP disabled
|
|
serverConfig := &Config{
|
|
HostKeyPEM: hostKey,
|
|
JWT: nil,
|
|
}
|
|
server := New(serverConfig)
|
|
server.SetAllowSFTP(false)
|
|
|
|
// Start server
|
|
serverAddr := "127.0.0.1:0"
|
|
started := make(chan string, 1)
|
|
errChan := make(chan error, 1)
|
|
|
|
go func() {
|
|
ln, err := net.Listen("tcp", serverAddr)
|
|
if err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
actualAddr := ln.Addr().String()
|
|
if err := ln.Close(); err != nil {
|
|
errChan <- fmt.Errorf("close temp listener: %w", err)
|
|
return
|
|
}
|
|
|
|
addrPort, _ := netip.ParseAddrPort(actualAddr)
|
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
|
errChan <- err
|
|
return
|
|
}
|
|
started <- actualAddr
|
|
}()
|
|
|
|
select {
|
|
case actualAddr := <-started:
|
|
serverAddr = actualAddr
|
|
case err := <-errChan:
|
|
t.Fatalf("Server failed to start: %v", err)
|
|
case <-time.After(5 * time.Second):
|
|
t.Fatal("Server start timeout")
|
|
}
|
|
|
|
defer func() {
|
|
err := server.Stop()
|
|
require.NoError(t, err)
|
|
}()
|
|
|
|
// Parse client private key
|
|
signer, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
|
require.NoError(t, err)
|
|
|
|
// Parse server host key
|
|
hostPrivParsed, err := cryptossh.ParsePrivateKey(hostKey)
|
|
require.NoError(t, err)
|
|
hostPubKey := hostPrivParsed.PublicKey()
|
|
|
|
// (currentUser already obtained at function start)
|
|
|
|
// Create SSH client connection
|
|
clientConfig := &cryptossh.ClientConfig{
|
|
User: currentUser.Username,
|
|
Auth: []cryptossh.AuthMethod{
|
|
cryptossh.PublicKeys(signer),
|
|
},
|
|
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
|
Timeout: 5 * time.Second,
|
|
}
|
|
|
|
conn, err := cryptossh.Dial("tcp", serverAddr, clientConfig)
|
|
require.NoError(t, err, "SSH connection should succeed")
|
|
defer func() {
|
|
if err := conn.Close(); err != nil {
|
|
t.Logf("connection close error: %v", err)
|
|
}
|
|
}()
|
|
|
|
// Try to create SFTP client - should fail when SFTP is disabled
|
|
_, err = sftp.NewClient(conn)
|
|
assert.Error(t, err, "SFTP client creation should fail when SFTP is disabled")
|
|
}
|