mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 09:16:40 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -21,6 +21,8 @@ import (
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -40,12 +42,10 @@ type Client struct {
|
||||
windowsStdinMode uint32 // nolint:unused
|
||||
}
|
||||
|
||||
// Close terminates the SSH connection
|
||||
func (c *Client) Close() error {
|
||||
return c.client.Close()
|
||||
}
|
||||
|
||||
// OpenTerminal opens an interactive terminal session
|
||||
func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||
session, err := c.client.NewSession()
|
||||
if err != nil {
|
||||
@@ -259,43 +259,29 @@ func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error
|
||||
return session, cleanup, nil
|
||||
}
|
||||
|
||||
// Dial connects to the given ssh server with proper host key verification
|
||||
func Dial(ctx context.Context, addr, user string) (*Client, error) {
|
||||
hostKeyCallback, err := createHostKeyCallback(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
|
||||
func getDefaultDaemonAddr() string {
|
||||
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
|
||||
return addr
|
||||
}
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
if runtime.GOOS == "windows" {
|
||||
return DefaultDaemonAddrWindows
|
||||
}
|
||||
|
||||
return dial(ctx, "tcp", addr, config)
|
||||
}
|
||||
|
||||
// DialInsecure connects to the given ssh server without host key verification (for testing only)
|
||||
func DialInsecure(ctx context.Context, addr, user string) (*Client, error) {
|
||||
config := &ssh.ClientConfig{
|
||||
User: user,
|
||||
Timeout: 30 * time.Second,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 - Only used for tests
|
||||
}
|
||||
|
||||
return dial(ctx, "tcp", addr, config)
|
||||
return DefaultDaemonAddr
|
||||
}
|
||||
|
||||
// DialOptions contains options for SSH connections
|
||||
type DialOptions struct {
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
DaemonAddr string
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
DaemonAddr string
|
||||
SkipCachedToken bool
|
||||
InsecureSkipVerify bool
|
||||
}
|
||||
|
||||
// DialWithOptions connects to the given ssh server with specified options
|
||||
func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||
hostKeyCallback, err := createHostKeyCallbackWithOptions(addr, opts)
|
||||
// Dial connects to the given ssh server with specified options
|
||||
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||
hostKeyCallback, err := createHostKeyCallback(opts)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||
}
|
||||
@@ -306,7 +292,6 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
}
|
||||
|
||||
// Add SSH key authentication if identity file is specified
|
||||
if opts.IdentityFile != "" {
|
||||
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
||||
if err != nil {
|
||||
@@ -315,11 +300,16 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
|
||||
config.Auth = append(config.Auth, authMethod)
|
||||
}
|
||||
|
||||
return dial(ctx, "tcp", addr, config)
|
||||
daemonAddr := opts.DaemonAddr
|
||||
if daemonAddr == "" {
|
||||
daemonAddr = getDefaultDaemonAddr()
|
||||
}
|
||||
|
||||
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||
}
|
||||
|
||||
// dial establishes an SSH connection
|
||||
func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
// dialSSH establishes an SSH connection without JWT authentication
|
||||
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||
dialer := &net.Dialer{}
|
||||
conn, err := dialer.DialContext(ctx, network, addr)
|
||||
if err != nil {
|
||||
@@ -340,143 +330,84 @@ func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (
|
||||
}, nil
|
||||
}
|
||||
|
||||
// createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files
|
||||
func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) {
|
||||
daemonAddr := os.Getenv("NB_DAEMON_ADDR")
|
||||
if daemonAddr == "" {
|
||||
if runtime.GOOS == "windows" {
|
||||
daemonAddr = DefaultDaemonAddrWindows
|
||||
} else {
|
||||
daemonAddr = DefaultDaemonAddr
|
||||
}
|
||||
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
||||
host, portStr, err := net.SplitHostPort(addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||
}
|
||||
return createHostKeyCallbackWithDaemonAddr(addr, daemonAddr)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||
}
|
||||
|
||||
if !serverType.RequiresJWT() {
|
||||
return dialSSH(ctx, network, addr, config)
|
||||
}
|
||||
|
||||
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||
defer cancel()
|
||||
|
||||
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||
}
|
||||
|
||||
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
|
||||
return dialSSH(ctx, network, addr, configWithJWT)
|
||||
}
|
||||
|
||||
// createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address
|
||||
func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCallback, error) {
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
// First try to get host key from NetBird daemon
|
||||
if err := verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr); err == nil {
|
||||
return nil
|
||||
}
|
||||
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
// Fallback to known_hosts files
|
||||
knownHostsFiles := getKnownHostsFiles()
|
||||
|
||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
||||
|
||||
for _, file := range knownHostsFiles {
|
||||
if callback, err := knownhosts.New(file); err == nil {
|
||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
||||
}
|
||||
}
|
||||
|
||||
// Try each known_hosts callback
|
||||
for _, callback := range hostKeyCallbacks {
|
||||
if err := callback(hostname, remote, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
|
||||
}, nil
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
|
||||
}
|
||||
|
||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
client, err := connectToDaemon(daemonAddr)
|
||||
conn, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("daemon connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
addresses := buildAddressList(hostname, remote)
|
||||
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
|
||||
|
||||
return verifyKeyWithDaemon(client, addresses, key)
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(client)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
|
||||
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
addr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
ctx,
|
||||
strings.TrimPrefix(daemonAddr, "tcp://"),
|
||||
conn, err := grpc.NewClient(
|
||||
addr,
|
||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||
grpc.WithBlock(),
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err)
|
||||
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
|
||||
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||
}
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
addresses := []string{hostname}
|
||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
if host != hostname {
|
||||
addresses = append(addresses, host)
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
for _, addr := range addresses {
|
||||
if err := checkAddressKey(client, addr, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
|
||||
}
|
||||
|
||||
func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: addr,
|
||||
})
|
||||
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("daemon query error for %s: %v", addr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
log.Debugf("SSH host key not found in daemon for address: %s", addr)
|
||||
return fmt.Errorf("key not found")
|
||||
}
|
||||
|
||||
return compareKeys(response.GetSshHostKey(), key, addr)
|
||||
}
|
||||
|
||||
func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error {
|
||||
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) {
|
||||
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type())
|
||||
return fmt.Errorf("key mismatch")
|
||||
}
|
||||
|
||||
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
||||
func getKnownHostsFiles() []string {
|
||||
var files []string
|
||||
@@ -503,8 +434,12 @@ func getKnownHostsFiles() []string {
|
||||
return files
|
||||
}
|
||||
|
||||
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options
|
||||
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||
// createHostKeyCallback creates a host key verification callback
|
||||
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||
if opts.InsecureSkipVerify {
|
||||
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
|
||||
}
|
||||
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||
return nil
|
||||
|
||||
@@ -7,29 +7,36 @@ import (
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
cleanupTestUsers()
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -39,19 +46,14 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create and start server
|
||||
server := sshserver.New(hostKey)
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
@@ -62,8 +64,10 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := getCurrentUsername()
|
||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client.Close()
|
||||
@@ -75,7 +79,7 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
||||
}
|
||||
|
||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||
if runtime.GOOS == "windows" && isCI() {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
@@ -129,20 +133,16 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
|
||||
}()
|
||||
|
||||
// Generate client key for multiple connections
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
err = server.AddAuthorizedKey("multi-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
const numClients = 3
|
||||
clients := make([]*Client, numClients)
|
||||
|
||||
for i := 0; i < numClients; i++ {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
currentUser := getCurrentUsername()
|
||||
client, err := DialInsecure(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i))
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
cancel()
|
||||
require.NoError(t, err, "Client %d should connect successfully", i)
|
||||
clients[i] = client
|
||||
@@ -161,19 +161,14 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
err = server.AddAuthorizedKey("cancel-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("connection with short timeout", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||
defer cancel()
|
||||
|
||||
currentUser := getCurrentUsername()
|
||||
_, err = DialInsecure(ctx, serverAddr, currentUser)
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
if err != nil {
|
||||
// Check for actual timeout-related errors rather than string matching
|
||||
assert.True(t,
|
||||
@@ -187,8 +182,10 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
|
||||
t.Run("command execution cancellation", func(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
currentUser := getCurrentUsername()
|
||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
@@ -214,7 +211,11 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := sshserver.New(hostKey)
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
@@ -226,10 +227,12 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := getCurrentUsername()
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
|
||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
assert.NoError(t, err)
|
||||
if client != nil {
|
||||
require.NoError(t, client.Close(), "Client should close without error")
|
||||
@@ -282,24 +285,22 @@ func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Clie
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := sshserver.New(hostKey)
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
currentUser := getCurrentUsername()
|
||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
||||
currentUser := testutil.GetTestUsername(t)
|
||||
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
return server, serverAddr, client
|
||||
@@ -361,18 +362,14 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := sshserver.New(hostKey)
|
||||
serverConfig := &sshserver.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := sshserver.New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(true)
|
||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
serverAddr := sshserver.StartTestServer(t, server)
|
||||
defer func() {
|
||||
err := server.Stop()
|
||||
@@ -387,11 +384,13 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Skip if running as system account that can't do port forwarding
|
||||
if isSystemAccount(realUser) {
|
||||
if testutil.IsSystemAccount(realUser) {
|
||||
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
|
||||
}
|
||||
|
||||
client, err := DialInsecure(ctx, serverAddr, realUser)
|
||||
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
|
||||
InsecureSkipVerify: true, // Skip host key verification for test
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
@@ -478,180 +477,6 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
||||
assert.Equal(t, expectedResponse, string(response))
|
||||
}
|
||||
|
||||
// getCurrentUsername returns the current username for SSH connections
|
||||
func getCurrentUsername() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
// Check if this is a system account that can't authenticate
|
||||
if isSystemAccount(currentUser.Username) {
|
||||
// In CI environments, create a test user; otherwise try Administrator
|
||||
if isCI() {
|
||||
if testUser := getOrCreateTestUser(); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
} else {
|
||||
// Try Administrator first for local development
|
||||
if _, err := user.Lookup("Administrator"); err == nil {
|
||||
return "Administrator"
|
||||
}
|
||||
if testUser := getOrCreateTestUser(); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
}
|
||||
}
|
||||
// On Windows, return the full domain\username for proper authentication
|
||||
return currentUser.Username
|
||||
}
|
||||
}
|
||||
|
||||
if username := os.Getenv("USER"); username != "" {
|
||||
return username
|
||||
}
|
||||
|
||||
if currentUser, err := user.Current(); err == nil {
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
return "test-user"
|
||||
}
|
||||
|
||||
// isCI checks if we're running in GitHub Actions CI
|
||||
func isCI() bool {
|
||||
// Check standard CI environment variables
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// getOrCreateTestUser creates a test user on Windows if needed
|
||||
func getOrCreateTestUser() string {
|
||||
testUsername := "netbird-test-user"
|
||||
|
||||
// Check if user already exists
|
||||
if _, err := user.Lookup(testUsername); err == nil {
|
||||
return testUsername
|
||||
}
|
||||
|
||||
// Try to create the user using PowerShell
|
||||
if createWindowsTestUser(testUsername) {
|
||||
// Register cleanup for the test user
|
||||
registerTestUserCleanup(testUsername)
|
||||
return testUsername
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
var createdTestUsers = make(map[string]bool)
|
||||
var testUsersToCleanup []string
|
||||
|
||||
// registerTestUserCleanup registers a test user for cleanup
|
||||
func registerTestUserCleanup(username string) {
|
||||
if !createdTestUsers[username] {
|
||||
createdTestUsers[username] = true
|
||||
testUsersToCleanup = append(testUsersToCleanup, username)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupTestUsers removes all created test users
|
||||
func cleanupTestUsers() {
|
||||
for _, username := range testUsersToCleanup {
|
||||
removeWindowsTestUser(username)
|
||||
}
|
||||
testUsersToCleanup = nil
|
||||
createdTestUsers = make(map[string]bool)
|
||||
}
|
||||
|
||||
// removeWindowsTestUser removes a local user on Windows using PowerShell
|
||||
func removeWindowsTestUser(username string) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
// PowerShell command to remove a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
||||
Write-Output "User removed successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*cannot be found*") {
|
||||
Write-Output "User not found (already removed)"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
}
|
||||
}
|
||||
`, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
||||
} else {
|
||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// createWindowsTestUser creates a local user on Windows using PowerShell
|
||||
func createWindowsTestUser(username string) bool {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false
|
||||
}
|
||||
|
||||
// PowerShell command to create a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
||||
Write-Output "User created successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*already exists*") {
|
||||
Write-Output "User already exists"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
`, username, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to create test user: %v, output: %s", err, string(output))
|
||||
return false
|
||||
}
|
||||
|
||||
log.Printf("Test user creation result: %s", string(output))
|
||||
return true
|
||||
}
|
||||
|
||||
// isSystemAccount checks if the user is a system account that can't authenticate
|
||||
func isSystemAccount(username string) bool {
|
||||
systemAccounts := []string{
|
||||
"system",
|
||||
"NT AUTHORITY\\SYSTEM",
|
||||
"NT AUTHORITY\\LOCAL SERVICE",
|
||||
"NT AUTHORITY\\NETWORK SERVICE",
|
||||
}
|
||||
|
||||
for _, sysAccount := range systemAccounts {
|
||||
if strings.EqualFold(username, sysAccount) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
|
||||
func getRealCurrentUser() (string, error) {
|
||||
if runtime.GOOS == "windows" {
|
||||
|
||||
167
client/ssh/common.go
Normal file
167
client/ssh/common.go
Normal file
@@ -0,0 +1,167 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
NetBirdSSHConfigFile = "99-netbird.conf"
|
||||
|
||||
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
|
||||
WindowsSSHConfigDir = "ssh/ssh_config.d"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPeerNotFound indicates the peer was not found in the network
|
||||
ErrPeerNotFound = errors.New("peer not found in network")
|
||||
// ErrNoStoredKey indicates the peer has no stored SSH host key
|
||||
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
|
||||
)
|
||||
|
||||
// HostKeyVerifier provides SSH host key verification
|
||||
type HostKeyVerifier interface {
|
||||
VerifySSHHostKey(peerAddress string, key []byte) error
|
||||
}
|
||||
|
||||
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
|
||||
type DaemonHostKeyVerifier struct {
|
||||
client proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
|
||||
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
|
||||
return &DaemonHostKeyVerifier{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
|
||||
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: peerAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
return ErrPeerNotFound
|
||||
}
|
||||
|
||||
storedKeyData := response.GetSshHostKey()
|
||||
|
||||
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||
}
|
||||
|
||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
|
||||
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request JWT auth: %w", err)
|
||||
}
|
||||
|
||||
if useCache && authResponse.CachedToken != "" {
|
||||
log.Debug("Using cached authentication token")
|
||||
return authResponse.CachedToken, nil
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
||||
if authResponse.UserCode != "" {
|
||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||
}
|
||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||
}
|
||||
|
||||
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||
DeviceCode: authResponse.DeviceCode,
|
||||
UserCode: authResponse.UserCode,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for JWT token: %w", err)
|
||||
}
|
||||
|
||||
if stdout != nil {
|
||||
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
|
||||
}
|
||||
return tokenResponse.Token, nil
|
||||
}
|
||||
|
||||
// VerifyHostKey verifies an SSH host key against stored peer key data.
|
||||
// Returns nil only if the presented key matches the stored key.
|
||||
// Returns ErrNoStoredKey if storedKeyData is empty.
|
||||
// Returns an error if the keys don't match or if parsing fails.
|
||||
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
|
||||
if len(storedKeyData) == 0 {
|
||||
return ErrNoStoredKey
|
||||
}
|
||||
|
||||
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
|
||||
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddJWTAuth prepends JWT password authentication to existing auth methods.
|
||||
// This ensures JWT auth is tried first while preserving any existing auth methods.
|
||||
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
|
||||
configWithJWT := *config
|
||||
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
|
||||
return &configWithJWT
|
||||
}
|
||||
|
||||
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
|
||||
// It tries multiple addresses (hostname, IP) for the peer before failing.
|
||||
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
addresses := buildAddressList(hostname, remote)
|
||||
presentedKey := key.Marshal()
|
||||
|
||||
for _, addr := range addresses {
|
||||
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
|
||||
if errors.Is(err, ErrPeerNotFound) {
|
||||
// Try other addresses for this peer
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Verified
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// buildAddressList creates a list of addresses to check for host key verification.
|
||||
// It includes the original hostname and extracts the host part from the remote address if different.
|
||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
addresses := []string{hostname}
|
||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
if host != hostname {
|
||||
addresses = append(addresses, host)
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -12,50 +11,41 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisableSSHConfig is the environment variable to disable SSH config management
|
||||
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
|
||||
|
||||
// EnvForceSSHConfig is the environment variable to force SSH config generation even with many peers
|
||||
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
|
||||
|
||||
// MaxPeersForSSHConfig is the default maximum number of peers before SSH config generation is disabled
|
||||
MaxPeersForSSHConfig = 200
|
||||
|
||||
// fileWriteTimeout is the timeout for file write operations
|
||||
fileWriteTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// isSSHConfigDisabled checks if SSH config management is disabled via environment variable
|
||||
func isSSHConfigDisabled() bool {
|
||||
value := os.Getenv(EnvDisableSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse as boolean, default to true if non-empty but invalid
|
||||
disabled, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
// If not a valid boolean, treat any non-empty value as true
|
||||
return true
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
// isSSHConfigForced checks if SSH config generation is forced via environment variable
|
||||
func isSSHConfigForced() bool {
|
||||
value := os.Getenv(EnvForceSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse as boolean, default to true if non-empty but invalid
|
||||
forced, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
// If not a valid boolean, treat any non-empty value as true
|
||||
return true
|
||||
}
|
||||
return forced
|
||||
@@ -92,85 +82,55 @@ func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error
|
||||
}
|
||||
}
|
||||
|
||||
// writeFileOperationWithTimeout performs a file operation with timeout
|
||||
func writeFileOperationWithTimeout(filename string, operation func() error) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- operation()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
|
||||
}
|
||||
}
|
||||
|
||||
// Manager handles SSH client configuration for NetBird peers
|
||||
type Manager struct {
|
||||
sshConfigDir string
|
||||
sshConfigFile string
|
||||
knownHostsDir string
|
||||
knownHostsFile string
|
||||
userKnownHosts string
|
||||
sshConfigDir string
|
||||
sshConfigFile string
|
||||
}
|
||||
|
||||
// PeerHostKey represents a peer's SSH host key information
|
||||
type PeerHostKey struct {
|
||||
// PeerSSHInfo represents a peer's SSH configuration information
|
||||
type PeerSSHInfo struct {
|
||||
Hostname string
|
||||
IP string
|
||||
FQDN string
|
||||
HostKey ssh.PublicKey
|
||||
}
|
||||
|
||||
// NewManager creates a new SSH config manager
|
||||
func NewManager() *Manager {
|
||||
sshConfigDir, knownHostsDir := getSystemSSHPaths()
|
||||
// New creates a new SSH config manager
|
||||
func New() *Manager {
|
||||
sshConfigDir := getSystemSSHConfigDir()
|
||||
return &Manager{
|
||||
sshConfigDir: sshConfigDir,
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: knownHostsDir,
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: sshConfigDir,
|
||||
sshConfigFile: nbssh.NetBirdSSHConfigFile,
|
||||
}
|
||||
}
|
||||
|
||||
// getSystemSSHPaths returns platform-specific SSH configuration paths
|
||||
func getSystemSSHPaths() (configDir, knownHostsDir string) {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
configDir, knownHostsDir = getWindowsSSHPaths()
|
||||
default:
|
||||
// Unix-like systems (Linux, macOS, etc.)
|
||||
configDir = "/etc/ssh/ssh_config.d"
|
||||
knownHostsDir = "/etc/ssh/ssh_known_hosts.d"
|
||||
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
|
||||
func getSystemSSHConfigDir() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return getWindowsSSHConfigDir()
|
||||
}
|
||||
return configDir, knownHostsDir
|
||||
return nbssh.UnixSSHConfigDir
|
||||
}
|
||||
|
||||
func getWindowsSSHPaths() (configDir, knownHostsDir string) {
|
||||
func getWindowsSSHConfigDir() string {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
configDir = filepath.Join(programData, "ssh", "ssh_config.d")
|
||||
knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d")
|
||||
return configDir, knownHostsDir
|
||||
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
|
||||
}
|
||||
|
||||
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
|
||||
func (m *Manager) SetupSSHClientConfig(peerKeys []PeerHostKey) error {
|
||||
if !shouldGenerateSSHConfig(len(peerKeys)) {
|
||||
m.logSkipReason(len(peerKeys))
|
||||
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
|
||||
if !shouldGenerateSSHConfig(len(peers)) {
|
||||
m.logSkipReason(len(peers))
|
||||
return nil
|
||||
}
|
||||
|
||||
knownHostsPath := m.getKnownHostsPath()
|
||||
sshConfig := m.buildSSHConfig(peerKeys, knownHostsPath)
|
||||
sshConfig, err := m.buildSSHConfig(peers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build SSH config: %w", err)
|
||||
}
|
||||
return m.writeSSHConfig(sshConfig)
|
||||
}
|
||||
|
||||
@@ -183,21 +143,24 @@ func (m *Manager) logSkipReason(peerCount int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getKnownHostsPath() string {
|
||||
knownHostsPath, err := m.setupKnownHostsFile()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to setup known_hosts file: %v", err)
|
||||
return "/dev/null"
|
||||
}
|
||||
return knownHostsPath
|
||||
}
|
||||
|
||||
func (m *Manager) buildSSHConfig(peerKeys []PeerHostKey, knownHostsPath string) string {
|
||||
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
|
||||
sshConfig := m.buildConfigHeader()
|
||||
for _, peer := range peerKeys {
|
||||
sshConfig += m.buildPeerConfig(peer, knownHostsPath)
|
||||
|
||||
var allHostPatterns []string
|
||||
for _, peer := range peers {
|
||||
hostPatterns := m.buildHostPatterns(peer)
|
||||
allHostPatterns = append(allHostPatterns, hostPatterns...)
|
||||
}
|
||||
return sshConfig
|
||||
|
||||
if len(allHostPatterns) > 0 {
|
||||
peerConfig, err := m.buildPeerConfig(allHostPatterns)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sshConfig += peerConfig
|
||||
}
|
||||
|
||||
return sshConfig, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildConfigHeader() string {
|
||||
@@ -209,25 +172,49 @@ func (m *Manager) buildConfigHeader() string {
|
||||
"#\n\n"
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerConfig(peer PeerHostKey, knownHostsPath string) string {
|
||||
hostPatterns := m.buildHostPatterns(peer)
|
||||
if len(hostPatterns) == 0 {
|
||||
return ""
|
||||
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
uniquePatterns := make(map[string]bool)
|
||||
var deduplicatedPatterns []string
|
||||
for _, pattern := range allHostPatterns {
|
||||
if !uniquePatterns[pattern] {
|
||||
uniquePatterns[pattern] = true
|
||||
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
hostLine := strings.Join(hostPatterns, " ")
|
||||
execPath, err := m.getNetBirdExecutablePath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||
}
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
config += " # NetBird peer-specific configuration\n"
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += m.buildHostKeyConfig(knownHostsPath)
|
||||
config += " LogLevel ERROR\n\n"
|
||||
return config
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
} else {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||
}
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
} else {
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
|
||||
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
var hostPatterns []string
|
||||
if peer.IP != "" {
|
||||
hostPatterns = append(hostPatterns, peer.IP)
|
||||
@@ -241,280 +228,55 @@ func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
|
||||
return hostPatterns
|
||||
}
|
||||
|
||||
func (m *Manager) buildHostKeyConfig(knownHostsPath string) string {
|
||||
if knownHostsPath == "/dev/null" {
|
||||
return " StrictHostKeyChecking no\n" +
|
||||
" UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
return " StrictHostKeyChecking yes\n" +
|
||||
fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath)
|
||||
}
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err)
|
||||
return m.setupUserConfig(sshConfig)
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
log.Warnf("Failed to write SSH config file %s: %v", sshConfigPath, err)
|
||||
return m.setupUserConfig(sshConfig)
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupUserConfig creates SSH config in user's directory as fallback
|
||||
func (m *Manager) setupUserConfig(sshConfig string) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user home directory: %w", err)
|
||||
}
|
||||
|
||||
userSSHDir := filepath.Join(homeDir, ".ssh")
|
||||
userConfigPath := filepath.Join(userSSHDir, "config")
|
||||
|
||||
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
|
||||
return fmt.Errorf("create user SSH directory: %w", err)
|
||||
}
|
||||
|
||||
// Check if NetBird config already exists in user config
|
||||
exists, err := m.configExists(userConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing config: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
log.Debugf("NetBird SSH config already exists in %s", userConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append NetBird config to user's SSH config with timeout
|
||||
if err := writeFileOperationWithTimeout(userConfigPath, func() error {
|
||||
file, err := os.OpenFile(userConfigPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open user SSH config: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Debugf("user SSH config file close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := fmt.Fprintf(file, "\n%s", sshConfig); err != nil {
|
||||
return fmt.Errorf("write to user SSH config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Added NetBird SSH config to user config: %s", userConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// configExists checks if NetBird SSH config already exists
|
||||
func (m *Manager) configExists(configPath string) (bool, error) {
|
||||
file, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("open SSH config file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.Contains(line, "NetBird SSH client configuration") {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, scanner.Err()
|
||||
}
|
||||
|
||||
// RemoveSSHClientConfig removes NetBird SSH configuration
|
||||
func (m *Manager) RemoveSSHClientConfig() error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
|
||||
// Remove system-wide config if it exists
|
||||
if err := os.Remove(sshConfigPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("Failed to remove system SSH config %s: %v", sshConfigPath, err)
|
||||
} else if err == nil {
|
||||
err := os.Remove(sshConfigPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
|
||||
}
|
||||
if err == nil {
|
||||
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
|
||||
}
|
||||
|
||||
// Also try to clean up user config
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user home directory: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
|
||||
if err := m.removeFromUserConfig(userConfigPath); err != nil {
|
||||
log.Warnf("Failed to clean user SSH config: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeFromUserConfig removes NetBird section from user's SSH config
|
||||
func (m *Manager) removeFromUserConfig(configPath string) error {
|
||||
// This is complex to implement safely, so for now just log
|
||||
// In practice, the system-wide config takes precedence anyway
|
||||
log.Debugf("NetBird SSH config cleanup from user config not implemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupKnownHostsFile creates and returns the path to NetBird known_hosts file
|
||||
func (m *Manager) setupKnownHostsFile() (string, error) {
|
||||
// Try system-wide known_hosts first
|
||||
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
|
||||
if err := os.MkdirAll(m.knownHostsDir, 0755); err == nil {
|
||||
// Create empty file if it doesn't exist
|
||||
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
|
||||
if err := writeFileWithTimeout(knownHostsPath, []byte("# NetBird SSH known hosts\n"), 0644); err == nil {
|
||||
log.Debugf("Created NetBird known_hosts file: %s", knownHostsPath)
|
||||
return knownHostsPath, nil
|
||||
}
|
||||
} else if err == nil {
|
||||
return knownHostsPath, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to user directory
|
||||
homeDir, err := os.UserHomeDir()
|
||||
func (m *Manager) getNetBirdExecutablePath() (string, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get user home directory: %w", err)
|
||||
return "", fmt.Errorf("retrieve executable path: %w", err)
|
||||
}
|
||||
|
||||
userSSHDir := filepath.Join(homeDir, ".ssh")
|
||||
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
|
||||
return "", fmt.Errorf("create user SSH directory: %w", err)
|
||||
}
|
||||
|
||||
userKnownHostsPath := filepath.Join(userSSHDir, m.userKnownHosts)
|
||||
if _, err := os.Stat(userKnownHostsPath); os.IsNotExist(err) {
|
||||
if err := writeFileWithTimeout(userKnownHostsPath, []byte("# NetBird SSH known hosts\n"), 0600); err != nil {
|
||||
return "", fmt.Errorf("create user known_hosts file: %w", err)
|
||||
}
|
||||
log.Debugf("Created NetBird user known_hosts file: %s", userKnownHostsPath)
|
||||
}
|
||||
|
||||
return userKnownHostsPath, nil
|
||||
}
|
||||
|
||||
// UpdatePeerHostKeys updates the known_hosts file with peer host keys
|
||||
func (m *Manager) UpdatePeerHostKeys(peerKeys []PeerHostKey) error {
|
||||
peerCount := len(peerKeys)
|
||||
|
||||
// Check if SSH config should be generated
|
||||
if !shouldGenerateSSHConfig(peerCount) {
|
||||
if isSSHConfigDisabled() {
|
||||
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
|
||||
} else {
|
||||
log.Infof("SSH known_hosts update skipped: too many peers (%d > %d). Use %s=true to force.",
|
||||
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
knownHostsPath, err := m.setupKnownHostsFile()
|
||||
realPath, err := filepath.EvalSymlinks(execPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup known_hosts file: %w", err)
|
||||
log.Debugf("symlink resolution failed: %v", err)
|
||||
return execPath, nil
|
||||
}
|
||||
|
||||
// Create updated known_hosts content - NetBird file should only contain NetBird entries
|
||||
var updatedContent strings.Builder
|
||||
updatedContent.WriteString("# NetBird SSH known hosts\n")
|
||||
updatedContent.WriteString("# Generated automatically - do not edit manually\n\n")
|
||||
|
||||
// Add new NetBird entries - one entry per peer with all hostnames
|
||||
for _, peerKey := range peerKeys {
|
||||
entry := m.formatKnownHostsEntry(peerKey)
|
||||
updatedContent.WriteString(entry)
|
||||
updatedContent.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write updated content
|
||||
if err := writeFileWithTimeout(knownHostsPath, []byte(updatedContent.String()), 0644); err != nil {
|
||||
return fmt.Errorf("write known_hosts file: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Updated NetBird known_hosts with %d peer keys: %s", len(peerKeys), knownHostsPath)
|
||||
return nil
|
||||
return realPath, nil
|
||||
}
|
||||
|
||||
// formatKnownHostsEntry formats a peer host key as a known_hosts entry
|
||||
func (m *Manager) formatKnownHostsEntry(peerKey PeerHostKey) string {
|
||||
hostnames := m.getHostnameVariants(peerKey)
|
||||
hostnameList := strings.Join(hostnames, ",")
|
||||
keyString := string(ssh.MarshalAuthorizedKey(peerKey.HostKey))
|
||||
keyString = strings.TrimSpace(keyString)
|
||||
return fmt.Sprintf("%s %s", hostnameList, keyString)
|
||||
// GetSSHConfigDir returns the SSH config directory path
|
||||
func (m *Manager) GetSSHConfigDir() string {
|
||||
return m.sshConfigDir
|
||||
}
|
||||
|
||||
// getHostnameVariants returns all possible hostname variants for a peer
|
||||
func (m *Manager) getHostnameVariants(peerKey PeerHostKey) []string {
|
||||
var hostnames []string
|
||||
|
||||
// Add IP address
|
||||
if peerKey.IP != "" {
|
||||
hostnames = append(hostnames, peerKey.IP)
|
||||
}
|
||||
|
||||
// Add FQDN
|
||||
if peerKey.FQDN != "" {
|
||||
hostnames = append(hostnames, peerKey.FQDN)
|
||||
}
|
||||
|
||||
// Add hostname if different from FQDN
|
||||
if peerKey.Hostname != "" && peerKey.Hostname != peerKey.FQDN {
|
||||
hostnames = append(hostnames, peerKey.Hostname)
|
||||
}
|
||||
|
||||
// Add bracketed IP for non-standard ports (SSH standard)
|
||||
if peerKey.IP != "" {
|
||||
hostnames = append(hostnames, fmt.Sprintf("[%s]:22", peerKey.IP))
|
||||
hostnames = append(hostnames, fmt.Sprintf("[%s]:22022", peerKey.IP))
|
||||
}
|
||||
|
||||
return hostnames
|
||||
}
|
||||
|
||||
// GetKnownHostsPath returns the path to the NetBird known_hosts file
|
||||
func (m *Manager) GetKnownHostsPath() (string, error) {
|
||||
return m.setupKnownHostsFile()
|
||||
}
|
||||
|
||||
// RemoveKnownHostsFile removes the NetBird known_hosts file
|
||||
func (m *Manager) RemoveKnownHostsFile() error {
|
||||
// Remove system-wide known_hosts if it exists
|
||||
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
|
||||
if err := os.Remove(knownHostsPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("Failed to remove system known_hosts %s: %v", knownHostsPath, err)
|
||||
} else if err == nil {
|
||||
log.Infof("Removed NetBird known_hosts: %s", knownHostsPath)
|
||||
}
|
||||
|
||||
// Also try to clean up user known_hosts
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user home directory: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
userKnownHostsPath := filepath.Join(homeDir, ".ssh", m.userKnownHosts)
|
||||
if err := os.Remove(userKnownHostsPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("Failed to remove user known_hosts %s: %v", userKnownHostsPath, err)
|
||||
} else if err == nil {
|
||||
log.Infof("Removed NetBird user known_hosts: %s", userKnownHostsPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
// GetSSHConfigFile returns the SSH config file name
|
||||
func (m *Manager) GetSSHConfigFile() string {
|
||||
return m.sshConfigFile
|
||||
}
|
||||
|
||||
@@ -10,81 +10,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestManager_UpdatePeerHostKeys(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
}
|
||||
|
||||
// Generate test host keys
|
||||
hostKey1, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey1, err := ssh.ParsePrivateKey(hostKey1)
|
||||
require.NoError(t, err)
|
||||
|
||||
hostKey2, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey2, err := ssh.ParsePrivateKey(hostKey2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test peer host keys
|
||||
peerKeys := []PeerHostKey{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
FQDN: "peer1.nb.internal",
|
||||
HostKey: pubKey1.PublicKey(),
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
FQDN: "peer2.nb.internal",
|
||||
HostKey: pubKey2.PublicKey(),
|
||||
},
|
||||
}
|
||||
|
||||
// Test updating known_hosts
|
||||
err = manager.UpdatePeerHostKeys(peerKeys)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify known_hosts file was created and contains entries
|
||||
knownHostsPath, err := manager.GetKnownHostsPath()
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(knownHostsPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, "100.125.1.1")
|
||||
assert.Contains(t, contentStr, "100.125.1.2")
|
||||
assert.Contains(t, contentStr, "peer1.nb.internal")
|
||||
assert.Contains(t, contentStr, "peer2.nb.internal")
|
||||
assert.Contains(t, contentStr, "[100.125.1.1]:22")
|
||||
assert.Contains(t, contentStr, "[100.125.1.1]:22022")
|
||||
|
||||
// Test updating with empty list should preserve structure
|
||||
err = manager.UpdatePeerHostKeys([]PeerHostKey{})
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err = os.ReadFile(knownHostsPath)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(content), "# NetBird SSH known hosts")
|
||||
}
|
||||
|
||||
func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
@@ -93,15 +20,25 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Test SSH config generation with empty peer keys
|
||||
err = manager.SetupSSHClientConfig(nil)
|
||||
// Test SSH config generation with peers
|
||||
peers := []PeerSSHInfo{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
FQDN: "peer1.nb.internal",
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
FQDN: "peer2.nb.internal",
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read generated config
|
||||
@@ -111,134 +48,39 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
|
||||
configStr := string(content)
|
||||
|
||||
// Since we now use per-peer configurations instead of domain patterns,
|
||||
// we should verify the basic SSH config structure exists
|
||||
// Verify the basic SSH config structure exists
|
||||
assert.Contains(t, configStr, "# NetBird SSH client configuration")
|
||||
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
|
||||
|
||||
// Should not contain /dev/null since we have a proper known_hosts setup
|
||||
assert.NotContains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||
}
|
||||
// Check that peer hostnames are included
|
||||
assert.Contains(t, configStr, "100.125.1.1")
|
||||
assert.Contains(t, configStr, "100.125.1.2")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
assert.Contains(t, configStr, "peer2.nb.internal")
|
||||
|
||||
func TestManager_GetHostnameVariants(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
peerKey := PeerHostKey{
|
||||
Hostname: "testpeer",
|
||||
IP: "100.125.1.10",
|
||||
FQDN: "testpeer.nb.internal",
|
||||
HostKey: nil, // Not needed for this test
|
||||
}
|
||||
|
||||
variants := manager.getHostnameVariants(peerKey)
|
||||
|
||||
expectedVariants := []string{
|
||||
"100.125.1.10",
|
||||
"testpeer.nb.internal",
|
||||
"testpeer",
|
||||
"[100.125.1.10]:22",
|
||||
"[100.125.1.10]:22022",
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, expectedVariants, variants)
|
||||
}
|
||||
|
||||
func TestManager_FormatKnownHostsEntry(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
// Generate test key
|
||||
hostKeyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
parsedKey, err := ssh.ParsePrivateKey(hostKeyPEM)
|
||||
require.NoError(t, err)
|
||||
|
||||
peerKey := PeerHostKey{
|
||||
Hostname: "testpeer",
|
||||
IP: "100.125.1.10",
|
||||
FQDN: "testpeer.nb.internal",
|
||||
HostKey: parsedKey.PublicKey(),
|
||||
}
|
||||
|
||||
entry := manager.formatKnownHostsEntry(peerKey)
|
||||
|
||||
// Should contain all hostname variants
|
||||
assert.Contains(t, entry, "100.125.1.10")
|
||||
assert.Contains(t, entry, "testpeer.nb.internal")
|
||||
assert.Contains(t, entry, "testpeer")
|
||||
assert.Contains(t, entry, "[100.125.1.10]:22")
|
||||
assert.Contains(t, entry, "[100.125.1.10]:22022")
|
||||
|
||||
// Should contain the public key
|
||||
keyString := string(ssh.MarshalAuthorizedKey(parsedKey.PublicKey()))
|
||||
keyString = strings.TrimSpace(keyString)
|
||||
assert.Contains(t, entry, keyString)
|
||||
|
||||
// Should be properly formatted (hostnames followed by key)
|
||||
parts := strings.Fields(entry)
|
||||
assert.GreaterOrEqual(t, len(parts), 2, "Entry should have hostnames and key parts")
|
||||
}
|
||||
|
||||
func TestManager_DirectoryFallback(t *testing.T) {
|
||||
// Create temporary directory for test where system dirs will fail
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Set HOME to temp directory to control user fallback
|
||||
t.Setenv("HOME", tempDir)
|
||||
|
||||
// Create manager with non-writable system directories
|
||||
// Use paths that will fail on all systems
|
||||
var failPath string
|
||||
// Check platform-specific UserKnownHostsFile
|
||||
if runtime.GOOS == "windows" {
|
||||
failPath = "NUL:" // Special device that can't be used as directory on Windows
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
|
||||
} else {
|
||||
failPath = "/dev/null" // Special device that can't be used as directory on Unix
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||
}
|
||||
|
||||
manager := &Manager{
|
||||
sshConfigDir: failPath + "/ssh_config.d", // Should fail
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: failPath + "/ssh_known_hosts.d", // Should fail
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
}
|
||||
|
||||
// Should fall back to user directory
|
||||
knownHostsPath, err := manager.setupKnownHostsFile()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the actual user home directory as determined by os.UserHomeDir()
|
||||
userHome, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedUserPath := filepath.Join(userHome, ".ssh", "known_hosts_netbird")
|
||||
assert.Equal(t, expectedUserPath, knownHostsPath)
|
||||
|
||||
// Verify file was created
|
||||
_, err = os.Stat(knownHostsPath)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetSystemSSHPaths(t *testing.T) {
|
||||
configDir, knownHostsDir := getSystemSSHPaths()
|
||||
func TestGetSystemSSHConfigDir(t *testing.T) {
|
||||
configDir := getSystemSSHConfigDir()
|
||||
|
||||
// Paths should not be empty
|
||||
// Path should not be empty
|
||||
assert.NotEmpty(t, configDir)
|
||||
assert.NotEmpty(t, knownHostsDir)
|
||||
|
||||
// Should be absolute paths
|
||||
// Should be an absolute path
|
||||
assert.True(t, filepath.IsAbs(configDir))
|
||||
assert.True(t, filepath.IsAbs(knownHostsDir))
|
||||
|
||||
// On Unix systems, should start with /etc
|
||||
// On Windows, should contain ProgramData
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, strings.ToLower(configDir), "programdata")
|
||||
assert.Contains(t, strings.ToLower(knownHostsDir), "programdata")
|
||||
} else {
|
||||
assert.Contains(t, configDir, "/etc/ssh")
|
||||
assert.Contains(t, knownHostsDir, "/etc/ssh")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,46 +92,28 @@ func TestManager_PeerLimit(t *testing.T) {
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peer keys (more than limit)
|
||||
var peerKeys []PeerHostKey
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey, err := ssh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
peerKeys = append(peerKeys, PeerHostKey{
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
HostKey: pubKey.PublicKey(),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is skipped when too many peers
|
||||
err = manager.SetupSSHClientConfig(peerKeys)
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should not be created due to peer limit
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||
|
||||
// Test that known_hosts update is also skipped
|
||||
err = manager.UpdatePeerHostKeys(peerKeys)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Known hosts should not be created due to peer limit
|
||||
knownHostsPath := filepath.Join(manager.knownHostsDir, manager.knownHostsFile)
|
||||
_, err = os.Stat(knownHostsPath)
|
||||
assert.True(t, os.IsNotExist(err), "Known hosts should not be created with too many peers")
|
||||
}
|
||||
|
||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
@@ -303,31 +127,22 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peer keys (more than limit)
|
||||
var peerKeys []PeerHostKey
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey, err := ssh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
peerKeys = append(peerKeys, PeerHostKey{
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
HostKey: pubKey.PublicKey(),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is forced despite many peers
|
||||
err = manager.SetupSSHClientConfig(peerKeys)
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should be created despite peer limit due to force flag
|
||||
|
||||
22
client/ssh/config/shutdown_state.go
Normal file
22
client/ssh/config/shutdown_state.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package config
|
||||
|
||||
// ShutdownState represents SSH configuration state that needs to be cleaned up.
|
||||
type ShutdownState struct {
|
||||
SSHConfigDir string
|
||||
SSHConfigFile string
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "ssh_config_state"
|
||||
}
|
||||
|
||||
// Cleanup removes SSH client configuration files.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager := &Manager{
|
||||
sshConfigDir: s.SSHConfigDir,
|
||||
sshConfigFile: s.SSHConfigFile,
|
||||
}
|
||||
|
||||
return manager.RemoveSSHClientConfig()
|
||||
}
|
||||
99
client/ssh/detection/detection.go
Normal file
99
client/ssh/detection/detection.go
Normal file
@@ -0,0 +1,99 @@
|
||||
package detection
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"net"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const (
|
||||
// ServerIdentifier is the base response for NetBird SSH servers
|
||||
ServerIdentifier = "NetBird-SSH-Server"
|
||||
// ProxyIdentifier is the base response for NetBird SSH proxy
|
||||
ProxyIdentifier = "NetBird-SSH-Proxy"
|
||||
// JWTRequiredMarker is appended to responses when JWT is required
|
||||
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||
|
||||
// Timeout is the timeout for SSH server detection
|
||||
Timeout = 5 * time.Second
|
||||
)
|
||||
|
||||
type ServerType string
|
||||
|
||||
const (
|
||||
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
|
||||
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
|
||||
ServerTypeRegular ServerType = "regular"
|
||||
)
|
||||
|
||||
// Dialer provides network connection capabilities
|
||||
type Dialer interface {
|
||||
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||
}
|
||||
|
||||
// RequiresJWT checks if the server type requires JWT authentication
|
||||
func (s ServerType) RequiresJWT() bool {
|
||||
return s == ServerTypeNetBirdJWT
|
||||
}
|
||||
|
||||
// ExitCode returns the exit code for the detect command
|
||||
func (s ServerType) ExitCode() int {
|
||||
switch s {
|
||||
case ServerTypeNetBirdJWT:
|
||||
return 0
|
||||
case ServerTypeNetBirdNoJWT:
|
||||
return 1
|
||||
case ServerTypeRegular:
|
||||
return 2
|
||||
default:
|
||||
return 2
|
||||
}
|
||||
}
|
||||
|
||||
// DetectSSHServerType detects SSH server type using the provided dialer
|
||||
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
|
||||
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||
if err != nil {
|
||||
log.Debugf("SSH connection failed for detection: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||
log.Debugf("set read deadline: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
reader := bufio.NewReader(conn)
|
||||
serverBanner, err := reader.ReadString('\n')
|
||||
if err != nil {
|
||||
log.Debugf("read SSH banner: %v", err)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
serverBanner = strings.TrimSpace(serverBanner)
|
||||
log.Debugf("SSH server banner: %s", serverBanner)
|
||||
|
||||
if !strings.HasPrefix(serverBanner, "SSH-") {
|
||||
log.Debugf("Invalid SSH banner")
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if !strings.Contains(serverBanner, ServerIdentifier) {
|
||||
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
|
||||
return ServerTypeRegular, nil
|
||||
}
|
||||
|
||||
if strings.Contains(serverBanner, JWTRequiredMarker) {
|
||||
return ServerTypeNetBirdJWT, nil
|
||||
}
|
||||
|
||||
return ServerTypeNetBirdNoJWT, nil
|
||||
}
|
||||
359
client/ssh/proxy/proxy.go
Normal file
359
client/ssh/proxy/proxy.go
Normal file
@@ -0,0 +1,359 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"os"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
const (
|
||||
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
|
||||
sshConnectionTimeout = 120 * time.Second
|
||||
// sshHandshakeTimeout is the timeout for SSH handshake completion
|
||||
sshHandshakeTimeout = 30 * time.Second
|
||||
|
||||
jwtAuthErrorMsg = "JWT authentication: %w"
|
||||
)
|
||||
|
||||
type SSHProxy struct {
|
||||
daemonAddr string
|
||||
targetHost string
|
||||
targetPort int
|
||||
stderr io.Writer
|
||||
daemonClient proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
||||
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
|
||||
return &SSHProxy{
|
||||
daemonAddr: daemonAddr,
|
||||
targetHost: targetHost,
|
||||
targetPort: targetPort,
|
||||
stderr: stderr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||
}
|
||||
|
||||
return p.runProxySSHServer(ctx, jwtToken)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
||||
|
||||
sshServer := &ssh.Server{
|
||||
Handler: func(s ssh.Session) {
|
||||
p.handleSSHSession(ctx, s, jwtToken)
|
||||
},
|
||||
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||
"session": ssh.DefaultSessionHandler,
|
||||
"direct-tcpip": p.directTCPIPHandler,
|
||||
},
|
||||
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||
"sftp": func(s ssh.Session) {
|
||||
p.sftpSubsystemHandler(s, jwtToken)
|
||||
},
|
||||
},
|
||||
RequestHandlers: map[string]ssh.RequestHandler{
|
||||
"tcpip-forward": p.tcpipForwardHandler,
|
||||
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
|
||||
},
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
hostKey, err := generateHostKey()
|
||||
if err != nil {
|
||||
return fmt.Errorf("generate host key: %w", err)
|
||||
}
|
||||
sshServer.HostSigners = []ssh.Signer{hostKey}
|
||||
|
||||
conn := &stdioConn{
|
||||
stdin: os.Stdin,
|
||||
stdout: os.Stdout,
|
||||
}
|
||||
|
||||
sshServer.HandleConn(conn)
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = sshClient.Close() }()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||
return
|
||||
}
|
||||
defer func() { _ = serverSession.Close() }()
|
||||
|
||||
serverSession.Stdin = session
|
||||
serverSession.Stdout = session
|
||||
serverSession.Stderr = session.Stderr()
|
||||
|
||||
ptyReq, winCh, isPty := session.Pty()
|
||||
if isPty {
|
||||
_ = serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil)
|
||||
|
||||
go func() {
|
||||
for win := range winCh {
|
||||
_ = serverSession.WindowChange(win.Height, win.Width)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
if len(session.Command()) > 0 {
|
||||
_ = serverSession.Run(strings.Join(session.Command(), " "))
|
||||
return
|
||||
}
|
||||
|
||||
if err = serverSession.Shell(); err == nil {
|
||||
_ = serverSession.Wait()
|
||||
}
|
||||
}
|
||||
|
||||
func generateHostKey() (ssh.Signer, error) {
|
||||
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("generate ED25519 key: %w", err)
|
||||
}
|
||||
|
||||
signer, err := cryptossh.ParsePrivateKey(keyPEM)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return signer, nil
|
||||
}
|
||||
|
||||
type stdioConn struct {
|
||||
stdin io.Reader
|
||||
stdout io.Writer
|
||||
closed bool
|
||||
mu sync.Mutex
|
||||
}
|
||||
|
||||
func (c *stdioConn) Read(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.EOF
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdin.Read(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Write(b []byte) (n int, err error) {
|
||||
c.mu.Lock()
|
||||
if c.closed {
|
||||
c.mu.Unlock()
|
||||
return 0, io.ErrClosedPipe
|
||||
}
|
||||
c.mu.Unlock()
|
||||
return c.stdout.Write(b)
|
||||
}
|
||||
|
||||
func (c *stdioConn) Close() error {
|
||||
c.mu.Lock()
|
||||
defer c.mu.Unlock()
|
||||
c.closed = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) LocalAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) RemoteAddr() net.Addr {
|
||||
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
|
||||
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||
ctx, cancel := context.WithCancel(s.Context())
|
||||
defer cancel()
|
||||
|
||||
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||
|
||||
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
log.Debugf("close SSH client: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
serverSession, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
defer func() {
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("close server session: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
stdin, stdout, err := p.setupSFTPPipes(serverSession)
|
||||
if err != nil {
|
||||
log.Debugf("setup SFTP pipes: %v", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
if err := serverSession.RequestSubsystem("sftp"); err != nil {
|
||||
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
|
||||
_ = s.Exit(1)
|
||||
return
|
||||
}
|
||||
|
||||
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
|
||||
}
|
||||
|
||||
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
|
||||
stdin, err := serverSession.StdinPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
|
||||
}
|
||||
|
||||
stdout, err := serverSession.StdoutPipe()
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
return stdin, stdout, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
|
||||
copyErrCh := make(chan error, 2)
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(stdin, s)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP client to server copy: %v", err)
|
||||
}
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("close stdin: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
_, err := io.Copy(s, stdout)
|
||||
if err != nil {
|
||||
log.Debugf("SFTP server to client copy: %v", err)
|
||||
}
|
||||
copyErrCh <- err
|
||||
}()
|
||||
|
||||
go func() {
|
||||
<-ctx.Done()
|
||||
if err := serverSession.Close(); err != nil {
|
||||
log.Debugf("force close server session on context cancellation: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 2; i++ {
|
||||
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("SFTP copy error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
if err := serverSession.Wait(); err != nil {
|
||||
log.Debugf("SFTP session ended: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return false, []byte("port forwarding not supported in proxy")
|
||||
}
|
||||
|
||||
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
|
||||
Timeout: sshHandshakeTimeout,
|
||||
HostKeyCallback: p.verifyHostKey,
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{
|
||||
Timeout: sshConnectionTimeout,
|
||||
}
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("connect to server: %w", err)
|
||||
}
|
||||
|
||||
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
|
||||
if err != nil {
|
||||
_ = conn.Close()
|
||||
return nil, fmt.Errorf("SSH handshake: %w", err)
|
||||
}
|
||||
|
||||
return cryptossh.NewClient(clientConn, chans, reqs), nil
|
||||
}
|
||||
|
||||
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
|
||||
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
|
||||
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||
return callback(hostname, remote, key)
|
||||
}
|
||||
361
client/ssh/proxy/proxy_test.go
Normal file
361
client/ssh/proxy/proxy_test.go
Normal file
@@ -0,0 +1,361 @@
|
||||
package proxy
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
)
|
||||
|
||||
func TestMain(m *testing.M) {
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" {
|
||||
if os.Args[2] == "exec" {
|
||||
if len(os.Args) > 3 {
|
||||
cmd := os.Args[3]
|
||||
if cmd == "echo" && len(os.Args) > 4 {
|
||||
fmt.Fprintln(os.Stdout, os.Args[4])
|
||||
os.Exit(0)
|
||||
}
|
||||
}
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
||||
|
||||
code := m.Run()
|
||||
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
|
||||
func TestSSHProxy_verifyHostKey(t *testing.T) {
|
||||
t.Run("calls daemon to verify host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
testPubKey, err := nbssh.GeneratePublicKey(testKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey("test-host", testPubKey)
|
||||
|
||||
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
|
||||
assert.NoError(t, err)
|
||||
})
|
||||
|
||||
t.Run("rejects unknown host key", func(t *testing.T) {
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
require.NoError(t, err)
|
||||
defer func() { _ = grpcConn.Close() }()
|
||||
|
||||
proxy := &SSHProxy{
|
||||
daemonAddr: mockDaemon.addr,
|
||||
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||
}
|
||||
|
||||
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
|
||||
})
|
||||
}
|
||||
|
||||
func TestSSHProxy_Connect(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping integration test in short mode")
|
||||
}
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
serverConfig := &server.Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: &server.JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
},
|
||||
}
|
||||
sshServer := server.New(serverConfig)
|
||||
sshServer.SetAllowRootLogin(true)
|
||||
|
||||
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||
defer func() { _ = sshServer.Stop() }()
|
||||
|
||||
mockDaemon := startMockDaemon(t)
|
||||
defer mockDaemon.stop()
|
||||
|
||||
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
mockDaemon.setHostKey(host, hostPubKey)
|
||||
|
||||
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||
mockDaemon.setJWTToken(validToken)
|
||||
|
||||
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
|
||||
require.NoError(t, err)
|
||||
|
||||
clientConn, proxyConn := net.Pipe()
|
||||
defer func() { _ = clientConn.Close() }()
|
||||
|
||||
origStdin := os.Stdin
|
||||
origStdout := os.Stdout
|
||||
defer func() {
|
||||
os.Stdin = origStdin
|
||||
os.Stdout = origStdout
|
||||
}()
|
||||
|
||||
stdinReader, stdinWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||
require.NoError(t, err)
|
||||
|
||||
os.Stdin = stdinReader
|
||||
os.Stdout = stdoutWriter
|
||||
|
||||
go func() {
|
||||
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||
}()
|
||||
go func() {
|
||||
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||
}()
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
connectErrCh := make(chan error, 1)
|
||||
go func() {
|
||||
connectErrCh <- proxyInstance.Connect(ctx)
|
||||
}()
|
||||
|
||||
sshConfig := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||
require.NoError(t, err, "Should connect to proxy server")
|
||||
defer func() { _ = sshClientConn.Close() }()
|
||||
|
||||
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||
|
||||
outputCh := make(chan []byte, 1)
|
||||
errCh := make(chan error, 1)
|
||||
go func() {
|
||||
output, err := session.Output("echo hello-from-proxy")
|
||||
outputCh <- output
|
||||
errCh <- err
|
||||
}()
|
||||
|
||||
select {
|
||||
case output := <-outputCh:
|
||||
err := <-errCh
|
||||
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||
case <-time.After(3 * time.Second):
|
||||
t.Fatal("Command execution timed out")
|
||||
}
|
||||
|
||||
_ = session.Close()
|
||||
_ = sshClient.Close()
|
||||
_ = clientConn.Close()
|
||||
cancel()
|
||||
}
|
||||
|
||||
type mockDaemonServer struct {
|
||||
proto.UnimplementedDaemonServiceServer
|
||||
hostKeys map[string][]byte
|
||||
jwtToken string
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||
key, found := m.hostKeys[req.PeerAddress]
|
||||
return &proto.GetPeerSSHHostKeyResponse{
|
||||
Found: found,
|
||||
SshHostKey: key,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
|
||||
return &proto.RequestJWTAuthResponse{
|
||||
CachedToken: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
|
||||
return &proto.WaitJWTTokenResponse{
|
||||
Token: m.jwtToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
type mockDaemon struct {
|
||||
addr string
|
||||
server *grpc.Server
|
||||
impl *mockDaemonServer
|
||||
}
|
||||
|
||||
func startMockDaemon(t *testing.T) *mockDaemon {
|
||||
t.Helper()
|
||||
|
||||
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
|
||||
impl := &mockDaemonServer{
|
||||
hostKeys: make(map[string][]byte),
|
||||
jwtToken: "test-jwt-token",
|
||||
}
|
||||
|
||||
grpcServer := grpc.NewServer()
|
||||
proto.RegisterDaemonServiceServer(grpcServer, impl)
|
||||
|
||||
go func() {
|
||||
_ = grpcServer.Serve(listener)
|
||||
}()
|
||||
|
||||
return &mockDaemon{
|
||||
addr: listener.Addr().String(),
|
||||
server: grpcServer,
|
||||
impl: impl,
|
||||
}
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
||||
m.impl.hostKeys[addr] = pubKey
|
||||
}
|
||||
|
||||
func (m *mockDaemon) setJWTToken(token string) {
|
||||
m.impl.jwtToken = token
|
||||
}
|
||||
|
||||
func (m *mockDaemon) stop() {
|
||||
if m.server != nil {
|
||||
m.server.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
||||
t.Helper()
|
||||
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
|
||||
require.NoError(t, err)
|
||||
return pubKey
|
||||
}
|
||||
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
t.Helper()
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
t.Helper()
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64.RawURLEncoding.EncodeToString(n),
|
||||
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
t.Helper()
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
@@ -9,7 +9,6 @@ import (
|
||||
"net"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
@@ -21,15 +20,24 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
)
|
||||
|
||||
// TestMain handles package-level setup and cleanup
|
||||
func TestMain(m *testing.M) {
|
||||
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||
// This happens when running tests as non-privileged user with fallback
|
||||
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||
// Just exit with error to break the recursion
|
||||
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
// Run tests
|
||||
code := m.Run()
|
||||
|
||||
// Cleanup any created test users
|
||||
cleanupTestUsers()
|
||||
testutil.CleanupTestUsers()
|
||||
|
||||
os.Exit(code)
|
||||
}
|
||||
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate OpenSSH-compatible keys for client
|
||||
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t)
|
||||
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get appropriate user for SSH connection (handle system accounts)
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
t.Run("basic command execution", func(t *testing.T) {
|
||||
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
||||
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
|
||||
// testSSHInteractiveCommand tests interactive shell session.
|
||||
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||
// testSSHPortForwarding tests port forwarding compatibility.
|
||||
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
require.NoError(t, err)
|
||||
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
||||
}
|
||||
|
||||
if runtime.GOOS == "windows" && isCI() {
|
||||
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
||||
}
|
||||
|
||||
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
||||
// testCommandWithFlags tests that commands with flags work properly
|
||||
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test ls with flags
|
||||
cmd := exec.Command("ssh",
|
||||
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||
// testEnvironmentVariables tests that environment is properly set up
|
||||
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
cmd := exec.Command("ssh",
|
||||
"-i", keyFile,
|
||||
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||
// testExitCodes tests that exit codes are properly handled
|
||||
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Test successful command (exit code 0)
|
||||
cmd := exec.Command("ssh",
|
||||
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server with specific security settings
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
}
|
||||
|
||||
// Get appropriate user for SSH connection
|
||||
username := getTestUsername(t)
|
||||
username := testutil.GetTestUsername(t)
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
||||
t.Logf("Platform command output: %s", outputStr)
|
||||
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
||||
}
|
||||
|
||||
// getTestUsername returns an appropriate username for testing
|
||||
func getTestUsername(t *testing.T) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
// Check if this is a system account that can't authenticate
|
||||
if isSystemAccount(currentUser.Username) {
|
||||
// In CI environments, create a test user; otherwise try Administrator
|
||||
if isCI() {
|
||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
} else {
|
||||
// Try Administrator first for local development
|
||||
if _, err := user.Lookup("Administrator"); err == nil {
|
||||
return "Administrator"
|
||||
}
|
||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
// isCI checks if we're running in a CI environment
|
||||
func isCI() bool {
|
||||
// Check standard CI environment variables
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
||||
return true
|
||||
}
|
||||
|
||||
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// isSystemAccount checks if the user is a system account that can't authenticate
|
||||
func isSystemAccount(username string) bool {
|
||||
systemAccounts := []string{
|
||||
"system",
|
||||
"NT AUTHORITY\\SYSTEM",
|
||||
"NT AUTHORITY\\LOCAL SERVICE",
|
||||
"NT AUTHORITY\\NETWORK SERVICE",
|
||||
}
|
||||
|
||||
for _, sysAccount := range systemAccounts {
|
||||
if strings.EqualFold(username, sysAccount) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
var compatTestCreatedUsers = make(map[string]bool)
|
||||
var compatTestUsersToCleanup []string
|
||||
|
||||
// registerTestUserCleanup registers a test user for cleanup
|
||||
func registerTestUserCleanup(username string) {
|
||||
if !compatTestCreatedUsers[username] {
|
||||
compatTestCreatedUsers[username] = true
|
||||
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupTestUsers removes all created test users
|
||||
func cleanupTestUsers() {
|
||||
for _, username := range compatTestUsersToCleanup {
|
||||
removeWindowsTestUser(username)
|
||||
}
|
||||
compatTestUsersToCleanup = nil
|
||||
compatTestCreatedUsers = make(map[string]bool)
|
||||
}
|
||||
|
||||
// getOrCreateTestUser creates a test user on Windows if needed
|
||||
func getOrCreateTestUser(t *testing.T) string {
|
||||
testUsername := "netbird-test-user"
|
||||
|
||||
// Check if user already exists
|
||||
if _, err := user.Lookup(testUsername); err == nil {
|
||||
return testUsername
|
||||
}
|
||||
|
||||
// Try to create the user using PowerShell
|
||||
if createWindowsTestUser(t, testUsername) {
|
||||
// Register cleanup for the test user
|
||||
registerTestUserCleanup(testUsername)
|
||||
return testUsername
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// removeWindowsTestUser removes a local user on Windows using PowerShell
|
||||
func removeWindowsTestUser(username string) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
// PowerShell command to remove a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
||||
Write-Output "User removed successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*cannot be found*") {
|
||||
Write-Output "User not found (already removed)"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
}
|
||||
}
|
||||
`, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
||||
} else {
|
||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// createWindowsTestUser creates a local user on Windows using PowerShell
|
||||
func createWindowsTestUser(t *testing.T, username string) bool {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false
|
||||
}
|
||||
|
||||
// PowerShell command to create a local user
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
||||
Write-Output "User created successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*already exists*") {
|
||||
Write-Output "User already exists"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
`, username, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
|
||||
return false
|
||||
}
|
||||
|
||||
t.Logf("Test user creation result: %s", string(output))
|
||||
return true
|
||||
}
|
||||
|
||||
610
client/ssh/server/jwt_test.go
Normal file
610
client/ssh/server/jwt_test.go
Normal file
@@ -0,0 +1,610 @@
|
||||
package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/rand"
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"math/big"
|
||||
"net"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strconv"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/golang-jwt/jwt/v5"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
)
|
||||
|
||||
func TestJWTEnforcement(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT enforcement tests in short mode")
|
||||
}
|
||||
|
||||
// Set up SSH server
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: "test-issuer",
|
||||
Audience: "test-audience",
|
||||
KeysLocation: "test-keys",
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
if err != nil {
|
||||
t.Logf("Detection failed: %v", err)
|
||||
}
|
||||
t.Logf("Detected server type: %s", serverType)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
|
||||
})
|
||||
|
||||
t.Run("allows_when_disabled", func(t *testing.T) {
|
||||
serverConfigNoJWT := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
serverNoJWT := New(serverConfigNoJWT)
|
||||
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
|
||||
serverNoJWT.SetAllowRootLogin(true)
|
||||
|
||||
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||
defer require.NoError(t, serverNoJWT.Stop())
|
||||
|
||||
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||
require.NoError(t, err)
|
||||
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||
assert.False(t, serverType.RequiresJWT())
|
||||
|
||||
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
|
||||
require.NoError(t, err)
|
||||
defer client.Close()
|
||||
})
|
||||
|
||||
}
|
||||
|
||||
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
|
||||
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||
privateKey, jwksJSON := generateTestJWKS(t)
|
||||
|
||||
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Set("Content-Type", "application/json")
|
||||
if _, err := w.Write(jwksJSON); err != nil {
|
||||
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||
}
|
||||
}))
|
||||
|
||||
return server, privateKey, server.URL
|
||||
}
|
||||
|
||||
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
|
||||
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||
require.NoError(t, err)
|
||||
|
||||
publicKey := &privateKey.PublicKey
|
||||
n := publicKey.N.Bytes()
|
||||
e := publicKey.E
|
||||
|
||||
jwk := nbjwt.JSONWebKey{
|
||||
Kty: "RSA",
|
||||
Kid: "test-key-id",
|
||||
Use: "sig",
|
||||
N: base64RawURLEncode(n),
|
||||
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
|
||||
}
|
||||
|
||||
jwks := nbjwt.Jwks{
|
||||
Keys: []nbjwt.JSONWebKey{jwk},
|
||||
}
|
||||
|
||||
jwksJSON, err := json.Marshal(jwks)
|
||||
require.NoError(t, err)
|
||||
|
||||
return privateKey, jwksJSON
|
||||
}
|
||||
|
||||
func base64RawURLEncode(data []byte) string {
|
||||
return base64.RawURLEncoding.EncodeToString(data)
|
||||
}
|
||||
|
||||
// generateValidJWT creates a valid JWT token for testing
|
||||
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||
claims := jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
}
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
return tokenString
|
||||
}
|
||||
|
||||
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
|
||||
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
|
||||
t.Helper()
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
|
||||
ctx := context.Background()
|
||||
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
}
|
||||
|
||||
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
|
||||
func TestJWTDetection(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT detection test in short mode")
|
||||
}
|
||||
|
||||
jwksServer, _, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
port, err := strconv.Atoi(portStr)
|
||||
require.NoError(t, err)
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||
assert.True(t, serverType.RequiresJWT())
|
||||
}
|
||||
|
||||
func TestJWTFailClose(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT fail-close tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
tokenClaims jwt.MapClaims
|
||||
}{
|
||||
{
|
||||
name: "blocks_token_missing_iat",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_sub",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_iss",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_missing_aud",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_issuer",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": "wrong-issuer",
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_token_wrong_audience",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": "wrong-audience",
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(time.Hour).Unix(),
|
||||
"iat": time.Now().Unix(),
|
||||
},
|
||||
},
|
||||
{
|
||||
name: "blocks_expired_token",
|
||||
tokenClaims: jwt.MapClaims{
|
||||
"iss": issuer,
|
||||
"aud": audience,
|
||||
"sub": "test-user",
|
||||
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
MaxTokenAge: 3600,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
|
||||
token.Header["kid"] = "test-key-id"
|
||||
tokenString, err := token.SignedString(privateKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.Password(tokenString),
|
||||
},
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
assert.Error(t, err, "Authentication should fail (fail-close)")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
|
||||
func TestJWTAuthentication(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("Skipping JWT authentication tests in short mode")
|
||||
}
|
||||
|
||||
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||
defer jwksServer.Close()
|
||||
|
||||
const (
|
||||
issuer = "https://test-issuer.example.com"
|
||||
audience = "test-audience"
|
||||
)
|
||||
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
token string
|
||||
wantAuthOK bool
|
||||
setupServer func(*Server)
|
||||
testOperation func(*testing.T, *cryptossh.Client, string) error
|
||||
wantOpSuccess bool
|
||||
}{
|
||||
{
|
||||
name: "allows_shell_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
return session.Shell()
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "rejects_invalid_token",
|
||||
token: "invalid",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_shell_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("echo test")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "blocks_command_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
output, err := session.CombinedOutput("ls")
|
||||
if err != nil {
|
||||
t.Logf("Command output: %s", string(output))
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_sftp_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
return session.RequestSubsystem("sftp")
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_sftp_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowSFTP(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
session, err := conn.NewSession()
|
||||
require.NoError(t, err)
|
||||
defer session.Close()
|
||||
|
||||
session.Stdout = io.Discard
|
||||
session.Stderr = io.Discard
|
||||
err = session.RequestSubsystem("sftp")
|
||||
if err == nil {
|
||||
err = session.Wait()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
{
|
||||
name: "allows_port_forward_with_jwt",
|
||||
token: "valid",
|
||||
wantAuthOK: true,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowRemotePortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: true,
|
||||
},
|
||||
{
|
||||
name: "blocks_port_forward_without_jwt",
|
||||
token: "",
|
||||
wantAuthOK: false,
|
||||
setupServer: func(s *Server) {
|
||||
s.SetAllowRootLogin(true)
|
||||
s.SetAllowLocalPortForwarding(true)
|
||||
},
|
||||
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||
if ln != nil {
|
||||
defer ln.Close()
|
||||
}
|
||||
return err
|
||||
},
|
||||
wantOpSuccess: false,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
jwtConfig := &JWTConfig{
|
||||
Issuer: issuer,
|
||||
Audience: audience,
|
||||
KeysLocation: jwksURL,
|
||||
}
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: jwtConfig,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
if tc.setupServer != nil {
|
||||
tc.setupServer(server)
|
||||
}
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer require.NoError(t, server.Stop())
|
||||
|
||||
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||
require.NoError(t, err)
|
||||
|
||||
var authMethods []cryptossh.AuthMethod
|
||||
if tc.token == "valid" {
|
||||
token := generateValidJWT(t, privateKey, issuer, audience)
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(token),
|
||||
}
|
||||
} else if tc.token == "invalid" {
|
||||
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
|
||||
authMethods = []cryptossh.AuthMethod{
|
||||
cryptossh.Password(invalidToken),
|
||||
}
|
||||
}
|
||||
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: testutil.GetTestUsername(t),
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||
Timeout: 2 * time.Second,
|
||||
}
|
||||
|
||||
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||
if tc.wantAuthOK {
|
||||
require.NoError(t, err, "JWT authentication should succeed")
|
||||
} else if err != nil {
|
||||
t.Logf("Connection failed as expected: %v", err)
|
||||
return
|
||||
}
|
||||
if conn != nil {
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
t.Logf("close connection: %v", err)
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
err = tc.testOperation(t, conn, serverAddr)
|
||||
if tc.wantOpSuccess {
|
||||
require.NoError(t, err, "Operation should succeed")
|
||||
} else {
|
||||
assert.Error(t, err, "Operation should fail")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -2,18 +2,27 @@ package server
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
gojwt "github.com/golang-jwt/jwt/v5"
|
||||
log "github.com/sirupsen/logrus"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
|
||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||
@@ -27,6 +36,9 @@ const (
|
||||
errExitSession = "exit session error: %v"
|
||||
|
||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||
|
||||
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||
DefaultJWTMaxTokenAge = 5 * 60
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -69,7 +81,6 @@ func (e *UserNotFoundError) Unwrap() error {
|
||||
}
|
||||
|
||||
// safeLogCommand returns a safe representation of the command for logging
|
||||
// Only logs the first argument to avoid leaking sensitive information
|
||||
func safeLogCommand(cmd []string) string {
|
||||
if len(cmd) == 0 {
|
||||
return "<empty>"
|
||||
@@ -80,17 +91,14 @@ func safeLogCommand(cmd []string) string {
|
||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||
}
|
||||
|
||||
// sshConnectionState tracks the state of an SSH connection
|
||||
type sshConnectionState struct {
|
||||
hasActivePortForward bool
|
||||
username string
|
||||
remoteAddr string
|
||||
}
|
||||
|
||||
// Server is the SSH server implementation
|
||||
type Server struct {
|
||||
sshServer *ssh.Server
|
||||
authorizedKeys map[string]ssh.PublicKey
|
||||
mu sync.RWMutex
|
||||
hostKeyPEM []byte
|
||||
sessions map[SessionKey]ssh.Session
|
||||
@@ -100,30 +108,53 @@ type Server struct {
|
||||
allowRemotePortForwarding bool
|
||||
allowRootLogin bool
|
||||
allowSFTP bool
|
||||
jwtEnabled bool
|
||||
|
||||
netstackNet *netstack.Net
|
||||
|
||||
wgAddress wgaddr.Address
|
||||
ifIdx int
|
||||
|
||||
remoteForwardListeners map[ForwardKey]net.Listener
|
||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
||||
|
||||
jwtValidator *jwt.Validator
|
||||
jwtExtractor *jwt.ClaimsExtractor
|
||||
jwtConfig *JWTConfig
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key
|
||||
func New(hostKeyPEM []byte) *Server {
|
||||
return &Server{
|
||||
type JWTConfig struct {
|
||||
Issuer string
|
||||
Audience string
|
||||
KeysLocation string
|
||||
MaxTokenAge int64
|
||||
}
|
||||
|
||||
// Config contains all SSH server configuration options
|
||||
type Config struct {
|
||||
// JWT authentication configuration. If nil, JWT authentication is disabled
|
||||
JWT *JWTConfig
|
||||
|
||||
// HostKey is the SSH server host key in PEM format
|
||||
HostKeyPEM []byte
|
||||
}
|
||||
|
||||
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||
// If jwtConfig is nil, JWT authentication is disabled
|
||||
func New(config *Config) *Server {
|
||||
s := &Server{
|
||||
mu: sync.RWMutex{},
|
||||
hostKeyPEM: hostKeyPEM,
|
||||
authorizedKeys: make(map[string]ssh.PublicKey),
|
||||
hostKeyPEM: config.HostKeyPEM,
|
||||
sessions: make(map[SessionKey]ssh.Session),
|
||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||
jwtEnabled: config.JWT != nil,
|
||||
jwtConfig: config.JWT,
|
||||
}
|
||||
|
||||
return s
|
||||
}
|
||||
|
||||
// Start runs the SSH server, automatically detecting netstack vs standard networking
|
||||
// Does all setup synchronously, then starts serving in a goroutine and returns immediately
|
||||
// Start runs the SSH server
|
||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -139,7 +170,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
|
||||
sshServer, err := s.createSSHServer(ln.Addr())
|
||||
if err != nil {
|
||||
s.cleanupOnError(ln)
|
||||
s.closeListener(ln)
|
||||
return fmt.Errorf("create SSH server: %w", err)
|
||||
}
|
||||
|
||||
@@ -154,7 +185,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// createListener creates a network listener based on netstack vs standard networking
|
||||
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
||||
if s.netstackNet != nil {
|
||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||
@@ -173,22 +203,15 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
|
||||
return ln, addr.String(), nil
|
||||
}
|
||||
|
||||
// closeListener safely closes a listener
|
||||
func (s *Server) closeListener(ln net.Listener) {
|
||||
if ln == nil {
|
||||
return
|
||||
}
|
||||
if err := ln.Close(); err != nil {
|
||||
log.Debugf("listener close error: %v", err)
|
||||
}
|
||||
}
|
||||
|
||||
// cleanupOnError cleans up resources when SSH server creation fails
|
||||
func (s *Server) cleanupOnError(ln net.Listener) {
|
||||
if s.ifIdx == 0 || ln == nil {
|
||||
return
|
||||
}
|
||||
|
||||
s.closeListener(ln)
|
||||
}
|
||||
|
||||
// Stop closes the SSH server
|
||||
func (s *Server) Stop() error {
|
||||
s.mu.Lock()
|
||||
@@ -207,28 +230,6 @@ func (s *Server) Stop() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveAuthorizedKey removes the SSH key for a peer
|
||||
func (s *Server) RemoveAuthorizedKey(peer string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
delete(s.authorizedKeys, peer)
|
||||
}
|
||||
|
||||
// AddAuthorizedKey adds an SSH key for a peer
|
||||
func (s *Server) AddAuthorizedKey(peer, newKey string) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse key: %w", err)
|
||||
}
|
||||
|
||||
s.authorizedKeys[peer] = parsedKey
|
||||
return nil
|
||||
}
|
||||
|
||||
// SetNetstackNet sets the netstack network for userspace networking
|
||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||
s.mu.Lock()
|
||||
@@ -243,34 +244,195 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
||||
s.wgAddress = addr
|
||||
}
|
||||
|
||||
// SetSocketFilter configures eBPF socket filtering for the SSH server
|
||||
func (s *Server) SetSocketFilter(ifIdx int) {
|
||||
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||
func (s *Server) ensureJWTValidator() error {
|
||||
s.mu.RLock()
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
s.mu.RUnlock()
|
||||
return nil
|
||||
}
|
||||
config := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
if config == nil {
|
||||
return fmt.Errorf("JWT config not set")
|
||||
}
|
||||
|
||||
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||
|
||||
validator := jwt.NewValidator(
|
||||
config.Issuer,
|
||||
[]string{config.Audience},
|
||||
config.KeysLocation,
|
||||
true,
|
||||
)
|
||||
|
||||
extractor := jwt.NewClaimsExtractor(
|
||||
jwt.WithAudience(config.Audience),
|
||||
)
|
||||
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.ifIdx = ifIdx
|
||||
|
||||
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
s.jwtValidator = validator
|
||||
s.jwtExtractor = extractor
|
||||
|
||||
log.Infof("JWT validator initialized successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
||||
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||
s.mu.RLock()
|
||||
defer s.mu.RUnlock()
|
||||
jwtValidator := s.jwtValidator
|
||||
jwtConfig := s.jwtConfig
|
||||
s.mu.RUnlock()
|
||||
|
||||
for _, allowed := range s.authorizedKeys {
|
||||
if ssh.KeysEqual(allowed, key) {
|
||||
if ctx != nil {
|
||||
log.Debugf("SSH key authentication successful for user %s from %s", ctx.User(), ctx.RemoteAddr())
|
||||
if jwtValidator == nil {
|
||||
return nil, fmt.Errorf("JWT validator not initialized")
|
||||
}
|
||||
|
||||
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
|
||||
if err != nil {
|
||||
if jwtConfig != nil {
|
||||
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||
}
|
||||
return true
|
||||
}
|
||||
return nil, fmt.Errorf("validate token: %w", err)
|
||||
}
|
||||
|
||||
if ctx != nil {
|
||||
log.Warnf("SSH key authentication failed for user %s from %s: key not authorized (type: %s, fingerprint: %s)",
|
||||
ctx.User(), ctx.RemoteAddr(), key.Type(), cryptossh.FingerprintSHA256(key))
|
||||
if err := s.checkTokenAge(token, jwtConfig); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return false
|
||||
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||
}
|
||||
|
||||
iat, ok := claims["iat"].(float64)
|
||||
if !ok {
|
||||
userID := extractUserID(token)
|
||||
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||
}
|
||||
|
||||
issuedAt := time.Unix(int64(iat), 0)
|
||||
tokenAge := time.Since(issuedAt)
|
||||
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
|
||||
if tokenAge > maxAge {
|
||||
userID := getUserIDFromClaims(claims)
|
||||
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
|
||||
s.mu.RLock()
|
||||
jwtExtractor := s.jwtExtractor
|
||||
s.mu.RUnlock()
|
||||
|
||||
if jwtExtractor == nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
|
||||
}
|
||||
|
||||
userAuth, err := jwtExtractor.ToUserAuth(token)
|
||||
if err != nil {
|
||||
userID := extractUserID(token)
|
||||
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
|
||||
}
|
||||
|
||||
if !s.hasSSHAccess(&userAuth) {
|
||||
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
|
||||
}
|
||||
|
||||
return &userAuth, nil
|
||||
}
|
||||
|
||||
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
|
||||
return userAuth.UserId != ""
|
||||
}
|
||||
|
||||
func extractUserID(token *gojwt.Token) string {
|
||||
if token == nil {
|
||||
return "unknown"
|
||||
}
|
||||
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||
if !ok {
|
||||
return "unknown"
|
||||
}
|
||||
return getUserIDFromClaims(claims)
|
||||
}
|
||||
|
||||
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||
return sub
|
||||
}
|
||||
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||
return userID
|
||||
}
|
||||
if email, ok := claims["email"].(string); ok && email != "" {
|
||||
return email
|
||||
}
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||
parts := strings.Split(tokenString, ".")
|
||||
if len(parts) != 3 {
|
||||
return nil, fmt.Errorf("invalid token format")
|
||||
}
|
||||
|
||||
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("decode payload: %w", err)
|
||||
}
|
||||
|
||||
var claims map[string]interface{}
|
||||
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||
return nil, fmt.Errorf("parse claims: %w", err)
|
||||
}
|
||||
|
||||
return claims, nil
|
||||
}
|
||||
|
||||
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||
if err := s.ensureJWTValidator(); err != nil {
|
||||
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
token, err := s.validateJWTToken(password)
|
||||
if err != nil {
|
||||
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
userAuth, err := s.extractAndValidateUser(token)
|
||||
if err != nil {
|
||||
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||
return false
|
||||
}
|
||||
|
||||
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||
return true
|
||||
}
|
||||
|
||||
// markConnectionActivePortForward marks an SSH connection as having an active port forward
|
||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -286,14 +448,12 @@ func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn,
|
||||
}
|
||||
}
|
||||
|
||||
// connectionCloseHandler cleans up connection state when SSH connections fail/close
|
||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
||||
// We can't extract the SSH connection from net.Conn directly
|
||||
// Connection cleanup will happen during session cleanup or via timeout
|
||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
||||
}
|
||||
|
||||
// findSessionKeyByContext finds the session key by matching SSH connection context
|
||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
if ctx == nil {
|
||||
return "unknown"
|
||||
@@ -319,14 +479,13 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||
// Return a temporary key that we'll fix up later
|
||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||
log.Debugf("using temporary session key for port forward tracking: %s", tempKey)
|
||||
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
||||
return tempKey
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
// connectionValidator validates incoming connections based on source IP
|
||||
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
s.mu.RLock()
|
||||
netbirdNetwork := s.wgAddress.Network
|
||||
@@ -340,8 +499,8 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
remoteAddr := conn.RemoteAddr()
|
||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||
if !ok {
|
||||
log.Debugf("SSH connection from non-TCP address %s allowed", remoteAddr)
|
||||
return conn
|
||||
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
|
||||
return nil
|
||||
}
|
||||
|
||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||
@@ -357,15 +516,14 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||
}
|
||||
|
||||
if !netbirdNetwork.Contains(remoteIP) {
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s (allowed range: %s)", remoteIP, netbirdNetwork)
|
||||
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("SSH connection from %s allowed", remoteIP)
|
||||
log.Infof("SSH connection from NetBird peer %s allowed", remoteIP)
|
||||
return conn
|
||||
}
|
||||
|
||||
// isShutdownError checks if the error is expected during normal shutdown
|
||||
func isShutdownError(err error) bool {
|
||||
if errors.Is(err, net.ErrClosed) {
|
||||
return true
|
||||
@@ -379,12 +537,16 @@ func isShutdownError(err error) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
// createSSHServer creates and configures the SSH server
|
||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
if err := enableUserSwitching(); err != nil {
|
||||
log.Warnf("failed to enable user switching: %v", err)
|
||||
}
|
||||
|
||||
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
|
||||
if s.jwtEnabled {
|
||||
serverVersion += " " + detection.JWTRequiredMarker
|
||||
}
|
||||
|
||||
server := &ssh.Server{
|
||||
Addr: addr.String(),
|
||||
Handler: s.sessionHandler,
|
||||
@@ -402,6 +564,11 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
},
|
||||
ConnCallback: s.connectionValidator,
|
||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
||||
Version: serverVersion,
|
||||
}
|
||||
|
||||
if s.jwtEnabled {
|
||||
server.PasswordHandler = s.passwordHandler
|
||||
}
|
||||
|
||||
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
|
||||
@@ -413,14 +580,12 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||
return server, nil
|
||||
}
|
||||
|
||||
// storeRemoteForwardListener stores a remote forward listener for cleanup
|
||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
s.remoteForwardListeners[key] = ln
|
||||
}
|
||||
|
||||
// removeRemoteForwardListener removes and closes a remote forward listener
|
||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
@@ -438,7 +603,6 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// directTCPIPHandler handles direct-tcpip channel requests for local port forwarding with privilege validation
|
||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||
var payload struct {
|
||||
Host string
|
||||
|
||||
@@ -22,12 +22,6 @@ func TestServer_RootLoginRestriction(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
allowRoot bool
|
||||
@@ -117,10 +111,12 @@ func TestServer_RootLoginRestriction(t *testing.T) {
|
||||
defer cleanup()
|
||||
|
||||
// Create server with specific configuration
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(tt.allowRoot)
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Test the userNameLookup method directly
|
||||
user, err := server.userNameLookup(tt.username)
|
||||
@@ -196,7 +192,11 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
// Create server with specific configuration
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
||||
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
||||
|
||||
@@ -234,17 +234,13 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
server := New(hostKey)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
serverAddr := StartTestServer(t, server)
|
||||
defer func() {
|
||||
@@ -263,7 +259,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel1()
|
||||
|
||||
client1, err := sshclient.DialInsecure(ctx1, serverAddr, currentUser.Username)
|
||||
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client1.Close()
|
||||
@@ -274,7 +272,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel2()
|
||||
|
||||
client2, err := sshclient.DialInsecure(ctx2, serverAddr, currentUser.Username)
|
||||
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||
InsecureSkipVerify: true,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
defer func() {
|
||||
err := client2.Close()
|
||||
|
||||
@@ -7,11 +7,9 @@ import (
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
@@ -19,82 +17,15 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestServer_AddAuthorizedKey(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
keys := map[string][]byte{}
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
keys[peer] = remotePubKey
|
||||
}
|
||||
|
||||
for peer, remotePubKey := range keys {
|
||||
k, ok := server.authorizedKeys[peer]
|
||||
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
|
||||
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RemoveAuthorizedKey(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
server.RemoveAuthorizedKey("remotePeer")
|
||||
|
||||
_, ok := server.authorizedKeys["remotePeer"]
|
||||
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
|
||||
}
|
||||
|
||||
func TestServer_PubKeyHandler(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
var keys []ssh.PublicKey
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
keys = append(keys, remoteParsedPubKey)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
accepted := server.publicKeyHandler(nil, key)
|
||||
assert.True(t, accepted, "SSH key should be accepted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(key)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: key,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with random port
|
||||
server := New(hostKey)
|
||||
|
||||
// Add client's public key as authorized
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server in background
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
server := New(hostKey)
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate authorized key
|
||||
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate unauthorized key (different from authorized)
|
||||
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with only one authorized key
|
||||
server := New(hostKey)
|
||||
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
|
||||
require.NoError(t, err)
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse unauthorized private key
|
||||
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey)
|
||||
// Generate a client private key for SSH protocol (server doesn't check it)
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Try to connect with unauthorized key
|
||||
// Try to connect with client key
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(unauthorizedSigner),
|
||||
cryptossh.PublicKeys(clientSigner),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// This should succeed in no-auth mode
|
||||
// This should succeed in no-auth mode (server doesn't verify keys)
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
||||
if conn != nil {
|
||||
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
serverAddr := "127.0.0.1:0"
|
||||
|
||||
// Test multiple start/stop cycles
|
||||
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server1 := New(hostKey)
|
||||
server2 := New(hostKey)
|
||||
serverConfig1 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server1 := New(serverConfig1)
|
||||
|
||||
serverConfig2 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server2 := New(serverConfig2)
|
||||
|
||||
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
||||
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
||||
|
||||
@@ -35,17 +35,15 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP enabled
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(true)
|
||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
||||
|
||||
// Add client's public key as authorized
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
server.SetAllowRootLogin(true)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -144,17 +142,15 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with SFTP disabled
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
server.SetAllowSFTP(false)
|
||||
|
||||
// Add client's public key as authorized
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
started := make(chan string, 1)
|
||||
|
||||
@@ -14,7 +14,6 @@ func StartTestServer(t *testing.T, server *Server) string {
|
||||
errChan := make(chan error, 1)
|
||||
|
||||
go func() {
|
||||
// Get a free port
|
||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||
if err != nil {
|
||||
errChan <- err
|
||||
@@ -26,9 +25,12 @@ func StartTestServer(t *testing.T, server *Server) string {
|
||||
return
|
||||
}
|
||||
|
||||
started <- actualAddr
|
||||
addrPort := netip.MustParseAddrPort(actualAddr)
|
||||
errChan <- server.Start(context.Background(), addrPort)
|
||||
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||
errChan <- err
|
||||
return
|
||||
}
|
||||
started <- actualAddr
|
||||
}()
|
||||
|
||||
select {
|
||||
|
||||
172
client/ssh/testutil/user_helpers.go
Normal file
172
client/ssh/testutil/user_helpers.go
Normal file
@@ -0,0 +1,172 @@
|
||||
package testutil
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"log"
|
||||
"os"
|
||||
"os/exec"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
var testCreatedUsers = make(map[string]bool)
|
||||
var testUsersToCleanup []string
|
||||
|
||||
// GetTestUsername returns an appropriate username for testing
|
||||
func GetTestUsername(t *testing.T) string {
|
||||
if runtime.GOOS == "windows" {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
|
||||
if IsSystemAccount(currentUser.Username) {
|
||||
if IsCI() {
|
||||
if testUser := GetOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
} else {
|
||||
if _, err := user.Lookup("Administrator"); err == nil {
|
||||
return "Administrator"
|
||||
}
|
||||
if testUser := GetOrCreateTestUser(t); testUser != "" {
|
||||
return testUser
|
||||
}
|
||||
}
|
||||
}
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user")
|
||||
return currentUser.Username
|
||||
}
|
||||
|
||||
// IsCI checks if we're running in a CI environment
|
||||
func IsCI() bool {
|
||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
||||
return true
|
||||
}
|
||||
|
||||
hostname, err := os.Hostname()
|
||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
||||
return true
|
||||
}
|
||||
|
||||
return false
|
||||
}
|
||||
|
||||
// IsSystemAccount checks if the user is a system account that can't authenticate
|
||||
func IsSystemAccount(username string) bool {
|
||||
systemAccounts := []string{
|
||||
"system",
|
||||
"NT AUTHORITY\\SYSTEM",
|
||||
"NT AUTHORITY\\LOCAL SERVICE",
|
||||
"NT AUTHORITY\\NETWORK SERVICE",
|
||||
}
|
||||
|
||||
for _, sysAccount := range systemAccounts {
|
||||
if strings.EqualFold(username, sysAccount) {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// RegisterTestUserCleanup registers a test user for cleanup
|
||||
func RegisterTestUserCleanup(username string) {
|
||||
if !testCreatedUsers[username] {
|
||||
testCreatedUsers[username] = true
|
||||
testUsersToCleanup = append(testUsersToCleanup, username)
|
||||
}
|
||||
}
|
||||
|
||||
// CleanupTestUsers removes all created test users
|
||||
func CleanupTestUsers() {
|
||||
for _, username := range testUsersToCleanup {
|
||||
RemoveWindowsTestUser(username)
|
||||
}
|
||||
testUsersToCleanup = nil
|
||||
testCreatedUsers = make(map[string]bool)
|
||||
}
|
||||
|
||||
// GetOrCreateTestUser creates a test user on Windows if needed
|
||||
func GetOrCreateTestUser(t *testing.T) string {
|
||||
testUsername := "netbird-test-user"
|
||||
|
||||
if _, err := user.Lookup(testUsername); err == nil {
|
||||
return testUsername
|
||||
}
|
||||
|
||||
if CreateWindowsTestUser(t, testUsername) {
|
||||
RegisterTestUserCleanup(testUsername)
|
||||
return testUsername
|
||||
}
|
||||
|
||||
return ""
|
||||
}
|
||||
|
||||
// RemoveWindowsTestUser removes a local user on Windows using PowerShell
|
||||
func RemoveWindowsTestUser(username string) {
|
||||
if runtime.GOOS != "windows" {
|
||||
return
|
||||
}
|
||||
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
||||
Write-Output "User removed successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*cannot be found*") {
|
||||
Write-Output "User not found (already removed)"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
}
|
||||
}
|
||||
`, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
||||
} else {
|
||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
||||
}
|
||||
}
|
||||
|
||||
// CreateWindowsTestUser creates a local user on Windows using PowerShell
|
||||
func CreateWindowsTestUser(t *testing.T, username string) bool {
|
||||
if runtime.GOOS != "windows" {
|
||||
return false
|
||||
}
|
||||
|
||||
psCmd := fmt.Sprintf(`
|
||||
try {
|
||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
||||
Write-Output "User created successfully"
|
||||
} catch {
|
||||
if ($_.Exception.Message -like "*already exists*") {
|
||||
Write-Output "User already exists"
|
||||
} else {
|
||||
Write-Error $_.Exception.Message
|
||||
exit 1
|
||||
}
|
||||
}
|
||||
`, username, username)
|
||||
|
||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||
output, err := cmd.CombinedOutput()
|
||||
|
||||
if err != nil {
|
||||
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
|
||||
return false
|
||||
}
|
||||
|
||||
t.Logf("Test user creation result: %s", string(output))
|
||||
return true
|
||||
}
|
||||
Reference in New Issue
Block a user