mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 18:26:41 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -11,6 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||
@@ -125,10 +126,15 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
||||
username = args[2].String()
|
||||
}
|
||||
|
||||
var jwtToken string
|
||||
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
|
||||
jwtToken = args[3].String()
|
||||
}
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
sshClient := ssh.NewClient(client)
|
||||
|
||||
if err := sshClient.Connect(host, port, username); err != nil {
|
||||
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
@@ -191,12 +197,43 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||
}))
|
||||
}
|
||||
|
||||
// createDetectSSHServerMethod creates the SSH server detection method
|
||||
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||
if len(args) < 2 {
|
||||
return js.ValueOf("error: requires host and port")
|
||||
}
|
||||
|
||||
host := args[0].String()
|
||||
port := args[1].Int()
|
||||
|
||||
return createPromise(func(resolve, reject js.Value) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||
defer cancel()
|
||||
|
||||
serverType, err := detectSSHServerType(ctx, client, host, port)
|
||||
if err != nil {
|
||||
reject.Invoke(err.Error())
|
||||
return
|
||||
}
|
||||
|
||||
resolve.Invoke(js.ValueOf(serverType.RequiresJWT()))
|
||||
})
|
||||
})
|
||||
}
|
||||
|
||||
// detectSSHServerType detects SSH server type using NetBird network connection
|
||||
func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) {
|
||||
return sshdetection.DetectSSHServerType(ctx, client, host, port)
|
||||
}
|
||||
|
||||
// createClientObject wraps the NetBird client in a JavaScript object
|
||||
func createClientObject(client *netbird.Client) js.Value {
|
||||
obj := make(map[string]interface{})
|
||||
|
||||
obj["start"] = createStartMethod(client)
|
||||
obj["stop"] = createStopMethod(client)
|
||||
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||
obj["createSSHConnection"] = createSSHMethod(client)
|
||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||
|
||||
@@ -13,6 +13,7 @@ import (
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
netbird "github.com/netbirdio/netbird/client/embed"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
@@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client {
|
||||
}
|
||||
|
||||
// Connect establishes an SSH connection through NetBird network
|
||||
func (c *Client) Connect(host string, port int, username string) error {
|
||||
func (c *Client) Connect(host string, port int, username, jwtToken string) error {
|
||||
addr := fmt.Sprintf("%s:%d", host, port)
|
||||
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||
|
||||
var authMethods []ssh.AuthMethod
|
||||
|
||||
nbConfig, err := c.nbClient.GetConfig()
|
||||
authMethods, err := c.getAuthMethods(jwtToken)
|
||||
if err != nil {
|
||||
return fmt.Errorf("get NetBird config: %w", err)
|
||||
return err
|
||||
}
|
||||
if nbConfig.SSHKey == "" {
|
||||
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
|
||||
}
|
||||
|
||||
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse NetBird SSH private key: %w", err)
|
||||
}
|
||||
|
||||
pubKey := signer.PublicKey()
|
||||
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
|
||||
|
||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
||||
|
||||
config := &ssh.ClientConfig{
|
||||
User: username,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
||||
HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient),
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
|
||||
@@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// getAuthMethods returns SSH authentication methods, preferring JWT if available
|
||||
func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) {
|
||||
if jwtToken != "" {
|
||||
logrus.Debugf("SSH: Using JWT password authentication")
|
||||
return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil
|
||||
}
|
||||
|
||||
logrus.Debugf("SSH: No JWT token, using public key authentication")
|
||||
|
||||
nbConfig, err := c.nbClient.GetConfig()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("get NetBird config: %w", err)
|
||||
}
|
||||
|
||||
if nbConfig.SSHKey == "" {
|
||||
return nil, fmt.Errorf("no NetBird SSH key available")
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey))
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("parse NetBird SSH private key: %w", err)
|
||||
}
|
||||
|
||||
logrus.Debugf("SSH: Added public key auth")
|
||||
return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
|
||||
}
|
||||
|
||||
// StartSession starts an SSH session with PTY
|
||||
func (c *Client) StartSession(cols, rows int) error {
|
||||
if c.sshClient == nil {
|
||||
|
||||
@@ -1,50 +0,0 @@
|
||||
//go:build js
|
||||
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"crypto/x509"
|
||||
"encoding/pem"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
)
|
||||
|
||||
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
|
||||
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
|
||||
keyStr := string(keyPEM)
|
||||
if !strings.Contains(keyStr, "-----BEGIN") {
|
||||
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
|
||||
}
|
||||
|
||||
signer, err := ssh.ParsePrivateKey(keyPEM)
|
||||
if err == nil {
|
||||
return signer, nil
|
||||
}
|
||||
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
|
||||
|
||||
block, _ := pem.Decode(keyPEM)
|
||||
if block == nil {
|
||||
keyPreview := string(keyPEM)
|
||||
if len(keyPreview) > 100 {
|
||||
keyPreview = keyPreview[:100]
|
||||
}
|
||||
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
|
||||
}
|
||||
|
||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
||||
if err != nil {
|
||||
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
|
||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(rsaKey)
|
||||
}
|
||||
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
||||
return ssh.NewSignerFromKey(ecKey)
|
||||
}
|
||||
return nil, fmt.Errorf("parse private key: %w", err)
|
||||
}
|
||||
|
||||
return ssh.NewSignerFromKey(key)
|
||||
}
|
||||
Reference in New Issue
Block a user