mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-21 09:46:40 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -5,10 +5,12 @@ import (
|
||||
"errors"
|
||||
"flag"
|
||||
"fmt"
|
||||
"net"
|
||||
"os"
|
||||
"os/signal"
|
||||
"os/user"
|
||||
"slices"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
@@ -16,6 +18,8 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
@@ -29,6 +33,7 @@ const (
|
||||
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||
disableSSHAuthFlag = "disable-ssh-auth"
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -41,6 +46,7 @@ var (
|
||||
strictHostKeyChecking bool
|
||||
knownHostsFile string
|
||||
identityFile string
|
||||
skipCachedToken bool
|
||||
)
|
||||
|
||||
var (
|
||||
@@ -49,6 +55,7 @@ var (
|
||||
enableSSHSFTP bool
|
||||
enableSSHLocalPortForward bool
|
||||
enableSSHRemotePortForward bool
|
||||
disableSSHAuth bool
|
||||
)
|
||||
|
||||
func init() {
|
||||
@@ -57,6 +64,7 @@ func init() {
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||
|
||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||
@@ -64,11 +72,14 @@ func init() {
|
||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
||||
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||
|
||||
sshCmd.AddCommand(sshSftpCmd)
|
||||
sshCmd.AddCommand(sshProxyCmd)
|
||||
sshCmd.AddCommand(sshDetectCmd)
|
||||
}
|
||||
|
||||
var sshCmd = &cobra.Command{
|
||||
@@ -335,31 +346,51 @@ func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers
|
||||
}
|
||||
|
||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string) {
|
||||
// sshFlags contains all SSH-related flags and parameters
|
||||
type sshFlags struct {
|
||||
Port int
|
||||
Username string
|
||||
Login string
|
||||
StrictHostKeyChecking bool
|
||||
KnownHostsFile string
|
||||
IdentityFile string
|
||||
SkipCachedToken bool
|
||||
ConfigPath string
|
||||
LogLevel string
|
||||
LocalForwards []string
|
||||
RemoteForwards []string
|
||||
Host string
|
||||
Command string
|
||||
}
|
||||
|
||||
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||
|
||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||
fs.SetOutput(nil)
|
||||
|
||||
portFlag := fs.Int("p", sshserver.DefaultSSHPort, "SSH port")
|
||||
flags := &sshFlags{}
|
||||
|
||||
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
||||
userFlag := fs.String("u", "", sshUsernameDesc)
|
||||
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||
fs.String("user", "", sshUsernameDesc)
|
||||
loginFlag := fs.String("login", "", sshUsernameDesc+" (alias for --user)")
|
||||
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||
|
||||
strictHostKeyCheckingFlag := fs.Bool("strict-host-key-checking", true, "Enable strict host key checking")
|
||||
knownHostsFlag := fs.String("o", "", "Path to known_hosts file")
|
||||
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||
fs.String("known-hosts", "", "Path to known_hosts file")
|
||||
identityFlag := fs.String("i", "", "Path to SSH private key file")
|
||||
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||
fs.String("identity", "", "Path to SSH private key file")
|
||||
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||
|
||||
configFlag := fs.String("c", defaultConfigPath, "Netbird config file location")
|
||||
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||
fs.String("config", defaultConfigPath, "Netbird config file location")
|
||||
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
||||
|
||||
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag
|
||||
return fs, flags
|
||||
}
|
||||
|
||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
@@ -375,7 +406,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
|
||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||
|
||||
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag := createSSHFlagSet()
|
||||
fs, flags := createSSHFlagSet()
|
||||
|
||||
if err := fs.Parse(filteredArgs); err != nil {
|
||||
return parseHostnameAndCommand(filteredArgs)
|
||||
@@ -386,22 +417,23 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||
return errors.New(hostArgumentRequired)
|
||||
}
|
||||
|
||||
port = *portFlag
|
||||
if *userFlag != "" {
|
||||
username = *userFlag
|
||||
} else if *loginFlag != "" {
|
||||
username = *loginFlag
|
||||
port = flags.Port
|
||||
if flags.Username != "" {
|
||||
username = flags.Username
|
||||
} else if flags.Login != "" {
|
||||
username = flags.Login
|
||||
}
|
||||
|
||||
strictHostKeyChecking = *strictHostKeyCheckingFlag
|
||||
knownHostsFile = *knownHostsFlag
|
||||
identityFile = *identityFlag
|
||||
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||
knownHostsFile = flags.KnownHostsFile
|
||||
identityFile = flags.IdentityFile
|
||||
skipCachedToken = flags.SkipCachedToken
|
||||
|
||||
if *configFlag != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = *configFlag
|
||||
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||
configPath = flags.ConfigPath
|
||||
}
|
||||
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = *logLevelFlag
|
||||
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||
logLevel = flags.LogLevel
|
||||
}
|
||||
|
||||
localForwards = localForwardFlags
|
||||
@@ -449,30 +481,20 @@ func parseHostnameAndCommand(args []string) error {
|
||||
|
||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||
target := fmt.Sprintf("%s:%d", addr, port)
|
||||
|
||||
var c *sshclient.Client
|
||||
var err error
|
||||
|
||||
if strictHostKeyChecking {
|
||||
c, err = sshclient.DialWithOptions(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
DaemonAddr: daemonAddr,
|
||||
})
|
||||
} else {
|
||||
c, err = sshclient.DialInsecure(ctx, target, username)
|
||||
}
|
||||
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||
KnownHostsFile: knownHostsFile,
|
||||
IdentityFile: identityFile,
|
||||
DaemonAddr: daemonAddr,
|
||||
SkipCachedToken: skipCachedToken,
|
||||
InsecureSkipVerify: !strictHostKeyChecking,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||
cmd.Printf("\nTroubleshooting steps:\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
|
||||
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||
if strictHostKeyChecking {
|
||||
cmd.Printf(" 4. Try --strict-host-key-checking=false to bypass host key verification\n")
|
||||
}
|
||||
cmd.Printf("\n")
|
||||
return fmt.Errorf("dial %s: %w", target, err)
|
||||
}
|
||||
|
||||
@@ -665,3 +687,65 @@ func normalizeLocalHost(host string) string {
|
||||
}
|
||||
return host
|
||||
}
|
||||
|
||||
var sshProxyCmd = &cobra.Command{
|
||||
Use: "proxy <host> <port>",
|
||||
Short: "Internal SSH proxy for native SSH client integration",
|
||||
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshProxyFn,
|
||||
}
|
||||
|
||||
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("invalid port: %s", portStr)
|
||||
}
|
||||
|
||||
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||
if err != nil {
|
||||
return fmt.Errorf("create SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||
return fmt.Errorf("SSH proxy: %w", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var sshDetectCmd = &cobra.Command{
|
||||
Use: "detect <host> <port>",
|
||||
Short: "Detect if a host is running NetBird SSH",
|
||||
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||
Hidden: true,
|
||||
Args: cobra.ExactArgs(2),
|
||||
RunE: sshDetectFn,
|
||||
}
|
||||
|
||||
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
host := args[0]
|
||||
portStr := args[1]
|
||||
|
||||
port, err := strconv.Atoi(portStr)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||
if err != nil {
|
||||
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||
}
|
||||
|
||||
os.Exit(serverType.ExitCode())
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -360,6 +360,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
|
||||
}
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
req.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
log.Errorf("parse interface name: %v", err)
|
||||
@@ -460,6 +463,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
ic.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(interfaceNameFlag).Changed {
|
||||
if err := parseInterfaceName(interfaceName); err != nil {
|
||||
return nil, err
|
||||
@@ -576,6 +583,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||
}
|
||||
|
||||
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||
}
|
||||
|
||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user