mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 09:16:40 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -7,11 +7,9 @@ import (
|
||||
"net/netip"
|
||||
"os/user"
|
||||
"runtime"
|
||||
"strings"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/gliderlabs/ssh"
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
cryptossh "golang.org/x/crypto/ssh"
|
||||
@@ -19,82 +17,15 @@ import (
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestServer_AddAuthorizedKey(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
keys := map[string][]byte{}
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
keys[peer] = remotePubKey
|
||||
}
|
||||
|
||||
for peer, remotePubKey := range keys {
|
||||
k, ok := server.authorizedKeys[peer]
|
||||
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
|
||||
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_RemoveAuthorizedKey(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
|
||||
server.RemoveAuthorizedKey("remotePeer")
|
||||
|
||||
_, ok := server.authorizedKeys["remotePeer"]
|
||||
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
|
||||
}
|
||||
|
||||
func TestServer_PubKeyHandler(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
server := New(key)
|
||||
|
||||
var keys []ssh.PublicKey
|
||||
for i := 0; i < 10; i++ {
|
||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
||||
require.NoError(t, err)
|
||||
keys = append(keys, remoteParsedPubKey)
|
||||
}
|
||||
|
||||
for _, key := range keys {
|
||||
accepted := server.publicKeyHandler(nil, key)
|
||||
assert.True(t, accepted, "SSH key should be accepted")
|
||||
}
|
||||
}
|
||||
|
||||
func TestServer_StartStop(t *testing.T) {
|
||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(key)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: key,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
err = server.Stop()
|
||||
assert.NoError(t, err)
|
||||
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with random port
|
||||
server := New(hostKey)
|
||||
|
||||
// Add client's public key as authorized
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server in background
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
|
||||
// Generate client key pair
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server
|
||||
server := New(hostKey)
|
||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
||||
require.NoError(t, err)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate authorized key
|
||||
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate unauthorized key (different from authorized)
|
||||
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create server with only one authorized key
|
||||
server := New(hostKey)
|
||||
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
|
||||
require.NoError(t, err)
|
||||
// Create server
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
|
||||
// Start server
|
||||
serverAddr := "127.0.0.1:0"
|
||||
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
require.NoError(t, err)
|
||||
}()
|
||||
|
||||
// Parse unauthorized private key
|
||||
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey)
|
||||
// Generate a client private key for SSH protocol (server doesn't check it)
|
||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Parse server host key
|
||||
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
||||
currentUser, err := user.Current()
|
||||
require.NoError(t, err, "Should be able to get current user for test")
|
||||
|
||||
// Try to connect with unauthorized key
|
||||
// Try to connect with client key
|
||||
config := &cryptossh.ClientConfig{
|
||||
User: currentUser.Username,
|
||||
Auth: []cryptossh.AuthMethod{
|
||||
cryptossh.PublicKeys(unauthorizedSigner),
|
||||
cryptossh.PublicKeys(clientSigner),
|
||||
},
|
||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||
Timeout: 3 * time.Second,
|
||||
}
|
||||
|
||||
// This should succeed in no-auth mode
|
||||
// This should succeed in no-auth mode (server doesn't verify keys)
|
||||
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
||||
if conn != nil {
|
||||
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server := New(hostKey)
|
||||
serverConfig := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server := New(serverConfig)
|
||||
serverAddr := "127.0.0.1:0"
|
||||
|
||||
// Test multiple start/stop cycles
|
||||
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
|
||||
server1 := New(hostKey)
|
||||
server2 := New(hostKey)
|
||||
serverConfig1 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server1 := New(serverConfig1)
|
||||
|
||||
serverConfig2 := &Config{
|
||||
HostKeyPEM: hostKey,
|
||||
JWT: nil,
|
||||
}
|
||||
server2 := New(serverConfig2)
|
||||
|
||||
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
||||
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
||||
|
||||
Reference in New Issue
Block a user