Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -5,10 +5,12 @@ import (
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"slices"
"strconv"
"strings"
"syscall"
@@ -16,6 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util"
)
@@ -29,6 +33,7 @@ const (
enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
)
var (
@@ -41,6 +46,7 @@ var (
strictHostKeyChecking bool
knownHostsFile string
identityFile string
skipCachedToken bool
)
var (
@@ -49,6 +55,7 @@ var (
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
)
func init() {
@@ -57,6 +64,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
@@ -64,11 +72,14 @@ func init() {
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
}
var sshCmd = &cobra.Command{
@@ -335,31 +346,51 @@ func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers
}
// createSSHFlagSet creates and configures the flag set for SSH command parsing
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string) {
// sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
portFlag := fs.Int("p", sshserver.DefaultSSHPort, "SSH port")
flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
userFlag := fs.String("u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.String("user", "", sshUsernameDesc)
loginFlag := fs.String("login", "", sshUsernameDesc+" (alias for --user)")
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
strictHostKeyCheckingFlag := fs.Bool("strict-host-key-checking", true, "Enable strict host key checking")
knownHostsFlag := fs.String("o", "", "Path to known_hosts file")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.String("known-hosts", "", "Path to known_hosts file")
identityFlag := fs.String("i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.String("identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
configFlag := fs.String("c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.String("config", defaultConfigPath, "Netbird config file location")
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag
return fs, flags
}
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
@@ -375,7 +406,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag := createSSHFlagSet()
fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil {
return parseHostnameAndCommand(filteredArgs)
@@ -386,22 +417,23 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
return errors.New(hostArgumentRequired)
}
port = *portFlag
if *userFlag != "" {
username = *userFlag
} else if *loginFlag != "" {
username = *loginFlag
port = flags.Port
if flags.Username != "" {
username = flags.Username
} else if flags.Login != "" {
username = flags.Login
}
strictHostKeyChecking = *strictHostKeyCheckingFlag
knownHostsFile = *knownHostsFlag
identityFile = *identityFlag
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
if *configFlag != getEnvOrDefault("CONFIG", configPath) {
configPath = *configFlag
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
}
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = *logLevelFlag
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = flags.LogLevel
}
localForwards = localForwardFlags
@@ -449,30 +481,20 @@ func parseHostnameAndCommand(args []string) error {
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
var c *sshclient.Client
var err error
if strictHostKeyChecking {
c, err = sshclient.DialWithOptions(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
})
} else {
c, err = sshclient.DialInsecure(ctx, target, username)
}
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
})
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
if strictHostKeyChecking {
cmd.Printf(" 4. Try --strict-host-key-checking=false to bypass host key verification\n")
}
cmd.Printf("\n")
return fmt.Errorf("dial %s: %w", target, err)
}
@@ -665,3 +687,65 @@ func normalizeLocalHost(host string) string {
}
return host
}
var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}
func sshProxyFn(cmd *cobra.Command, args []string) error {
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}
return nil
}
var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
os.Exit(serverType.ExitCode())
return nil
}