mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 02:06:39 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -17,9 +17,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
"github.com/netbirdio/netbird/client/net"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/client/net"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"flag"
|
"flag"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/signal"
|
"os/signal"
|
||||||
"os/user"
|
"os/user"
|
||||||
"slices"
|
"slices"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
@@ -16,6 +18,8 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
sshclient "github.com/netbirdio/netbird/client/ssh/client"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
|
||||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
@@ -29,6 +33,7 @@ const (
|
|||||||
enableSSHSFTPFlag = "enable-ssh-sftp"
|
enableSSHSFTPFlag = "enable-ssh-sftp"
|
||||||
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
|
||||||
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
|
||||||
|
disableSSHAuthFlag = "disable-ssh-auth"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -41,6 +46,7 @@ var (
|
|||||||
strictHostKeyChecking bool
|
strictHostKeyChecking bool
|
||||||
knownHostsFile string
|
knownHostsFile string
|
||||||
identityFile string
|
identityFile string
|
||||||
|
skipCachedToken bool
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -49,6 +55,7 @@ var (
|
|||||||
enableSSHSFTP bool
|
enableSSHSFTP bool
|
||||||
enableSSHLocalPortForward bool
|
enableSSHLocalPortForward bool
|
||||||
enableSSHRemotePortForward bool
|
enableSSHRemotePortForward bool
|
||||||
|
disableSSHAuth bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -57,6 +64,7 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
|
||||||
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
|
||||||
|
|
||||||
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
|
||||||
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
|
||||||
@@ -64,11 +72,14 @@ func init() {
|
|||||||
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
|
||||||
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
|
||||||
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
|
||||||
|
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
|
||||||
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
|
||||||
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
|
||||||
|
|
||||||
sshCmd.AddCommand(sshSftpCmd)
|
sshCmd.AddCommand(sshSftpCmd)
|
||||||
|
sshCmd.AddCommand(sshProxyCmd)
|
||||||
|
sshCmd.AddCommand(sshDetectCmd)
|
||||||
}
|
}
|
||||||
|
|
||||||
var sshCmd = &cobra.Command{
|
var sshCmd = &cobra.Command{
|
||||||
@@ -335,31 +346,51 @@ func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers
|
|||||||
}
|
}
|
||||||
|
|
||||||
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
// createSSHFlagSet creates and configures the flag set for SSH command parsing
|
||||||
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string) {
|
// sshFlags contains all SSH-related flags and parameters
|
||||||
|
type sshFlags struct {
|
||||||
|
Port int
|
||||||
|
Username string
|
||||||
|
Login string
|
||||||
|
StrictHostKeyChecking bool
|
||||||
|
KnownHostsFile string
|
||||||
|
IdentityFile string
|
||||||
|
SkipCachedToken bool
|
||||||
|
ConfigPath string
|
||||||
|
LogLevel string
|
||||||
|
LocalForwards []string
|
||||||
|
RemoteForwards []string
|
||||||
|
Host string
|
||||||
|
Command string
|
||||||
|
}
|
||||||
|
|
||||||
|
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
|
||||||
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
|
||||||
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
|
||||||
|
|
||||||
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
|
||||||
fs.SetOutput(nil)
|
fs.SetOutput(nil)
|
||||||
|
|
||||||
portFlag := fs.Int("p", sshserver.DefaultSSHPort, "SSH port")
|
flags := &sshFlags{}
|
||||||
|
|
||||||
|
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
|
||||||
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
|
||||||
userFlag := fs.String("u", "", sshUsernameDesc)
|
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
|
||||||
fs.String("user", "", sshUsernameDesc)
|
fs.String("user", "", sshUsernameDesc)
|
||||||
loginFlag := fs.String("login", "", sshUsernameDesc+" (alias for --user)")
|
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
|
||||||
|
|
||||||
strictHostKeyCheckingFlag := fs.Bool("strict-host-key-checking", true, "Enable strict host key checking")
|
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
|
||||||
knownHostsFlag := fs.String("o", "", "Path to known_hosts file")
|
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
|
||||||
fs.String("known-hosts", "", "Path to known_hosts file")
|
fs.String("known-hosts", "", "Path to known_hosts file")
|
||||||
identityFlag := fs.String("i", "", "Path to SSH private key file")
|
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
|
||||||
fs.String("identity", "", "Path to SSH private key file")
|
fs.String("identity", "", "Path to SSH private key file")
|
||||||
|
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
|
||||||
|
|
||||||
configFlag := fs.String("c", defaultConfigPath, "Netbird config file location")
|
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
|
||||||
fs.String("config", defaultConfigPath, "Netbird config file location")
|
fs.String("config", defaultConfigPath, "Netbird config file location")
|
||||||
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level")
|
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
|
||||||
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
|
||||||
|
|
||||||
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag
|
return fs, flags
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
||||||
@@ -375,7 +406,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
|
||||||
|
|
||||||
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag := createSSHFlagSet()
|
fs, flags := createSSHFlagSet()
|
||||||
|
|
||||||
if err := fs.Parse(filteredArgs); err != nil {
|
if err := fs.Parse(filteredArgs); err != nil {
|
||||||
return parseHostnameAndCommand(filteredArgs)
|
return parseHostnameAndCommand(filteredArgs)
|
||||||
@@ -386,22 +417,23 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
|
|||||||
return errors.New(hostArgumentRequired)
|
return errors.New(hostArgumentRequired)
|
||||||
}
|
}
|
||||||
|
|
||||||
port = *portFlag
|
port = flags.Port
|
||||||
if *userFlag != "" {
|
if flags.Username != "" {
|
||||||
username = *userFlag
|
username = flags.Username
|
||||||
} else if *loginFlag != "" {
|
} else if flags.Login != "" {
|
||||||
username = *loginFlag
|
username = flags.Login
|
||||||
}
|
}
|
||||||
|
|
||||||
strictHostKeyChecking = *strictHostKeyCheckingFlag
|
strictHostKeyChecking = flags.StrictHostKeyChecking
|
||||||
knownHostsFile = *knownHostsFlag
|
knownHostsFile = flags.KnownHostsFile
|
||||||
identityFile = *identityFlag
|
identityFile = flags.IdentityFile
|
||||||
|
skipCachedToken = flags.SkipCachedToken
|
||||||
|
|
||||||
if *configFlag != getEnvOrDefault("CONFIG", configPath) {
|
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
|
||||||
configPath = *configFlag
|
configPath = flags.ConfigPath
|
||||||
}
|
}
|
||||||
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
|
||||||
logLevel = *logLevelFlag
|
logLevel = flags.LogLevel
|
||||||
}
|
}
|
||||||
|
|
||||||
localForwards = localForwardFlags
|
localForwards = localForwardFlags
|
||||||
@@ -449,30 +481,20 @@ func parseHostnameAndCommand(args []string) error {
|
|||||||
|
|
||||||
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
|
||||||
target := fmt.Sprintf("%s:%d", addr, port)
|
target := fmt.Sprintf("%s:%d", addr, port)
|
||||||
|
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
|
||||||
var c *sshclient.Client
|
KnownHostsFile: knownHostsFile,
|
||||||
var err error
|
IdentityFile: identityFile,
|
||||||
|
DaemonAddr: daemonAddr,
|
||||||
if strictHostKeyChecking {
|
SkipCachedToken: skipCachedToken,
|
||||||
c, err = sshclient.DialWithOptions(ctx, target, username, sshclient.DialOptions{
|
InsecureSkipVerify: !strictHostKeyChecking,
|
||||||
KnownHostsFile: knownHostsFile,
|
})
|
||||||
IdentityFile: identityFile,
|
|
||||||
DaemonAddr: daemonAddr,
|
|
||||||
})
|
|
||||||
} else {
|
|
||||||
c, err = sshclient.DialInsecure(ctx, target, username)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
cmd.Printf("Failed to connect to %s@%s\n", username, target)
|
||||||
cmd.Printf("\nTroubleshooting steps:\n")
|
cmd.Printf("\nTroubleshooting steps:\n")
|
||||||
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
|
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
|
||||||
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
|
||||||
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
|
||||||
if strictHostKeyChecking {
|
|
||||||
cmd.Printf(" 4. Try --strict-host-key-checking=false to bypass host key verification\n")
|
|
||||||
}
|
|
||||||
cmd.Printf("\n")
|
|
||||||
return fmt.Errorf("dial %s: %w", target, err)
|
return fmt.Errorf("dial %s: %w", target, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -665,3 +687,65 @@ func normalizeLocalHost(host string) string {
|
|||||||
}
|
}
|
||||||
return host
|
return host
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var sshProxyCmd = &cobra.Command{
|
||||||
|
Use: "proxy <host> <port>",
|
||||||
|
Short: "Internal SSH proxy for native SSH client integration",
|
||||||
|
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: sshProxyFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshProxyFn(cmd *cobra.Command, args []string) error {
|
||||||
|
host := args[0]
|
||||||
|
portStr := args[1]
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid port: %s", portStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("create SSH proxy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := proxy.Connect(cmd.Context()); err != nil {
|
||||||
|
return fmt.Errorf("SSH proxy: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var sshDetectCmd = &cobra.Command{
|
||||||
|
Use: "detect <host> <port>",
|
||||||
|
Short: "Detect if a host is running NetBird SSH",
|
||||||
|
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
|
||||||
|
Hidden: true,
|
||||||
|
Args: cobra.ExactArgs(2),
|
||||||
|
RunE: sshDetectFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func sshDetectFn(cmd *cobra.Command, args []string) error {
|
||||||
|
if err := util.InitLog(logLevel, "console"); err != nil {
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
host := args[0]
|
||||||
|
portStr := args[1]
|
||||||
|
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||||
|
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
|
||||||
|
if err != nil {
|
||||||
|
os.Exit(detection.ServerTypeRegular.ExitCode())
|
||||||
|
}
|
||||||
|
|
||||||
|
os.Exit(serverType.ExitCode())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -360,6 +360,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
|
|||||||
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
|
||||||
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
|
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
|
||||||
}
|
}
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
req.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
log.Errorf("parse interface name: %v", err)
|
log.Errorf("parse interface name: %v", err)
|
||||||
@@ -460,6 +463,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
|
|||||||
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
ic.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -576,6 +583,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
|
|||||||
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableSSHAuthFlag).Changed {
|
||||||
|
loginRequest.DisableSSHAuth = &disableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
if cmd.Flag(disableAutoConnectFlag).Changed {
|
if cmd.Flag(disableAutoConnectFlag).Changed {
|
||||||
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
loginRequest.DisableAutoConnect = &autoConnectDisabled
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -18,12 +18,16 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
var (
|
||||||
var ErrClientNotStarted = errors.New("client not started")
|
ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrClientNotStarted = errors.New("client not started")
|
||||||
|
ErrEngineNotStarted = errors.New("engine not started")
|
||||||
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
|
)
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
|||||||
// Dial dials a network address in the netbird network.
|
// Dial dials a network address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
c.mu.Lock()
|
engine, err := c.getEngine()
|
||||||
connect := c.connect
|
if err != nil {
|
||||||
if connect == nil {
|
return nil, err
|
||||||
c.mu.Unlock()
|
|
||||||
return nil, ErrClientNotStarted
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
|
||||||
|
|
||||||
engine := connect.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return nil, errors.New("engine not started")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
nsnet, err := engine.GetNet()
|
nsnet, err := engine.GetNet()
|
||||||
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
|||||||
return nsnet.DialContext(ctx, network, address)
|
return nsnet.DialContext(ctx, network, address)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// DialContext dials a network address in the netbird network with context
|
||||||
|
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
return c.Dial(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
// ListenTCP listens on the given address in the netbird network.
|
// ListenTCP listens on the given address in the netbird network.
|
||||||
// Not applicable if the userspace networking mode is disabled.
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||||
|
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||||
|
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||||
|
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||||
|
if !found {
|
||||||
|
return sshcommon.ErrPeerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// getEngine safely retrieves the engine from the client with proper locking.
|
||||||
|
// Returns ErrClientNotStarted if the client is not started.
|
||||||
|
// Returns ErrEngineNotStarted if the engine is not available.
|
||||||
|
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
if connect == nil {
|
|
||||||
c.mu.Unlock()
|
|
||||||
return nil, netip.Addr{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
if connect == nil {
|
||||||
|
return nil, ErrClientNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine == nil {
|
if engine == nil {
|
||||||
return nil, netip.Addr{}, errors.New("engine not started")
|
return nil, ErrEngineNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
return engine, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.Addr{}, err
|
||||||
}
|
}
|
||||||
|
|
||||||
addr, err := engine.Address()
|
addr, err := engine.Address()
|
||||||
|
|||||||
@@ -87,7 +87,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
|
|||||||
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
||||||
rules := d.squashAcceptRules(networkMap)
|
rules := d.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
|
|
||||||
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
|
||||||
// we have old version of management without rules handling, we should allow all traffic
|
// we have old version of management without rules handling, we should allow all traffic
|
||||||
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
|
||||||
@@ -350,7 +349,7 @@ func (d *DefaultManager) getPeerRuleID(
|
|||||||
//
|
//
|
||||||
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
|
||||||
// but other has port definitions or has drop policy.
|
// but other has port definitions or has drop policy.
|
||||||
func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap, ) []*mgmProto.FirewallRule {
|
func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap) []*mgmProto.FirewallRule {
|
||||||
totalIPs := 0
|
totalIPs := 0
|
||||||
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
|
||||||
for range p.AllowedIps {
|
for range p.AllowedIps {
|
||||||
|
|||||||
@@ -25,6 +25,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
@@ -34,7 +35,6 @@ import (
|
|||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -437,6 +437,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
|
|||||||
EnableSSHSFTP: config.EnableSSHSFTP,
|
EnableSSHSFTP: config.EnableSSHSFTP,
|
||||||
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
|
||||||
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
|
||||||
|
DisableSSHAuth: config.DisableSSHAuth,
|
||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
@@ -527,6 +528,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.EnableSSHSFTP,
|
config.EnableSSHSFTP,
|
||||||
config.EnableSSHLocalPortForwarding,
|
config.EnableSSHLocalPortForwarding,
|
||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -49,6 +49,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||||
|
|
||||||
@@ -117,6 +118,7 @@ type EngineConfig struct {
|
|||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
|
|
||||||
DNSRouteInterval time.Duration
|
DNSRouteInterval time.Duration
|
||||||
|
|
||||||
@@ -264,6 +266,7 @@ func NewEngine(
|
|||||||
path = mobileDep.StateFilePath
|
path = mobileDep.StateFilePath
|
||||||
}
|
}
|
||||||
engine.stateManager = statemanager.New(path)
|
engine.stateManager = statemanager.New(path)
|
||||||
|
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||||
|
|
||||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||||
return engine
|
return engine
|
||||||
@@ -676,14 +679,10 @@ func (e *Engine) removeAllPeers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
|
// removePeer closes an existing peer connection and removes a peer
|
||||||
func (e *Engine) removePeer(peerKey string) error {
|
func (e *Engine) removePeer(peerKey string) error {
|
||||||
log.Debugf("removing peer from engine %s", peerKey)
|
log.Debugf("removing peer from engine %s", peerKey)
|
||||||
|
|
||||||
if e.sshServer != nil {
|
|
||||||
e.sshServer.RemoveAuthorizedKey(peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
e.connMgr.RemovePeerConn(peerKey)
|
e.connMgr.RemovePeerConn(peerKey)
|
||||||
|
|
||||||
err := e.statusRecorder.RemovePeer(peerKey)
|
err := e.statusRecorder.RemovePeer(peerKey)
|
||||||
@@ -859,6 +858,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
|||||||
e.config.EnableSSHSFTP,
|
e.config.EnableSSHSFTP,
|
||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := e.mgmClient.SyncMeta(info); err != nil {
|
if err := e.mgmClient.SyncMeta(info); err != nil {
|
||||||
@@ -920,6 +920,7 @@ func (e *Engine) receiveManagementEvents() {
|
|||||||
e.config.EnableSSHSFTP,
|
e.config.EnableSSHSFTP,
|
||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
|
||||||
@@ -1074,24 +1075,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
e.statusRecorder.FinishPeerListModifications()
|
e.statusRecorder.FinishPeerListModifications()
|
||||||
|
|
||||||
// update SSHServer by adding remote peer SSH keys
|
|
||||||
if e.sshServer != nil {
|
|
||||||
for _, config := range networkMap.GetRemotePeers() {
|
|
||||||
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
|
|
||||||
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// update peer SSH host keys in status recorder for daemon API access
|
|
||||||
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
|
||||||
|
|
||||||
// update SSH client known_hosts with peer host keys for OpenSSH client
|
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
|
||||||
if err := e.updateSSHKnownHosts(networkMap.GetRemotePeers()); err != nil {
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
log.Warnf("failed to update SSH known_hosts: %v", err)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1480,6 +1467,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
|
|||||||
e.config.EnableSSHSFTP,
|
e.config.EnableSSHSFTP,
|
||||||
e.config.EnableSSHLocalPortForwarding,
|
e.config.EnableSSHLocalPortForwarding,
|
||||||
e.config.EnableSSHRemotePortForwarding,
|
e.config.EnableSSHRemotePortForwarding,
|
||||||
|
e.config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
|
|
||||||
netMap, err := e.mgmClient.GetNetworkMap(info)
|
netMap, err := e.mgmClient.GetNetworkMap(info)
|
||||||
|
|||||||
@@ -5,10 +5,8 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
@@ -21,9 +19,6 @@ import (
|
|||||||
type sshServer interface {
|
type sshServer interface {
|
||||||
Start(ctx context.Context, addr netip.AddrPort) error
|
Start(ctx context.Context, addr netip.AddrPort) error
|
||||||
Stop() error
|
Stop() error
|
||||||
RemoveAuthorizedKey(peer string)
|
|
||||||
AddAuthorizedKey(peer, newKey string) error
|
|
||||||
SetSocketFilter(ifIdx int)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) setupSSHPortRedirection() error {
|
func (e *Engine) setupSSHPortRedirection() error {
|
||||||
@@ -44,22 +39,6 @@ func (e *Engine) setupSSHPortRedirection() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) setupSSHSocketFilter(server sshServer) error {
|
|
||||||
if runtime.GOOS != "linux" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
netInterface := e.wgInterface.ToInterface()
|
|
||||||
if netInterface == nil {
|
|
||||||
return errors.New("failed to get WireGuard network interface")
|
|
||||||
}
|
|
||||||
|
|
||||||
server.SetSocketFilter(netInterface.Index)
|
|
||||||
log.Debugf("SSH socket filter configured for interface %s (index: %d)", netInterface.Name, netInterface.Index)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
||||||
if e.config.BlockInbound {
|
if e.config.BlockInbound {
|
||||||
log.Info("SSH server is disabled because inbound connections are blocked")
|
log.Info("SSH server is disabled because inbound connections are blocked")
|
||||||
@@ -83,66 +62,76 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
return e.startSSHServer()
|
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
|
||||||
|
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
|
||||||
|
return e.startSSHServer(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
|
||||||
|
jwtConfig := &sshserver.JWTConfig{
|
||||||
|
Issuer: protoJWT.GetIssuer(),
|
||||||
|
Audience: protoJWT.GetAudience(),
|
||||||
|
KeysLocation: protoJWT.GetKeysLocation(),
|
||||||
|
MaxTokenAge: protoJWT.GetMaxTokenAge(),
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.startSSHServer(jwtConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
return errors.New("SSH server requires valid JWT configuration")
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateSSHKnownHosts updates the SSH known_hosts file with peer host keys for OpenSSH client
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
func (e *Engine) updateSSHKnownHosts(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
peerKeys := e.extractPeerHostKeys(remotePeers)
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
if len(peerKeys) == 0 {
|
if len(peerInfo) == 0 {
|
||||||
log.Debug("no SSH-enabled peers found, skipping known_hosts update")
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.updateKnownHostsFile(peerKeys); err != nil {
|
configMgr := sshconfig.New()
|
||||||
return err
|
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
|
||||||
|
log.Warnf("failed to update SSH client config: %v", err)
|
||||||
|
return nil // Don't fail engine startup on SSH config issues
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
|
||||||
|
|
||||||
|
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
|
||||||
|
SSHConfigDir: configMgr.GetSSHConfigDir(),
|
||||||
|
SSHConfigFile: configMgr.GetSSHConfigFile(),
|
||||||
|
}); err != nil {
|
||||||
|
log.Warnf("failed to update SSH config state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
e.updateSSHClientConfig(peerKeys)
|
|
||||||
log.Debugf("updated SSH known_hosts with %d peer host keys", len(peerKeys))
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractPeerHostKeys extracts SSH host keys from peer configurations
|
// extractPeerSSHInfo extracts SSH information from peer configurations
|
||||||
func (e *Engine) extractPeerHostKeys(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerHostKey {
|
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
|
||||||
var peerKeys []sshconfig.PeerHostKey
|
var peerInfo []sshconfig.PeerSSHInfo
|
||||||
|
|
||||||
for _, peerConfig := range remotePeers {
|
for _, peerConfig := range remotePeers {
|
||||||
peerHostKey, ok := e.parsePeerHostKey(peerConfig)
|
if peerConfig.GetSshConfig() == nil {
|
||||||
if ok {
|
continue
|
||||||
peerKeys = append(peerKeys, peerHostKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
||||||
|
if len(sshPubKeyBytes) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerIP := e.extractPeerIP(peerConfig)
|
||||||
|
hostname := e.extractHostname(peerConfig)
|
||||||
|
|
||||||
|
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
|
||||||
|
Hostname: hostname,
|
||||||
|
IP: peerIP,
|
||||||
|
FQDN: peerConfig.GetFqdn(),
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
return peerKeys
|
return peerInfo
|
||||||
}
|
|
||||||
|
|
||||||
// parsePeerHostKey parses a single peer's SSH host key configuration
|
|
||||||
func (e *Engine) parsePeerHostKey(peerConfig *mgmProto.RemotePeerConfig) (sshconfig.PeerHostKey, bool) {
|
|
||||||
if peerConfig.GetSshConfig() == nil {
|
|
||||||
return sshconfig.PeerHostKey{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
|
|
||||||
if len(sshPubKeyBytes) == 0 {
|
|
||||||
return sshconfig.PeerHostKey{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyBytes)
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse SSH public key for peer %s: %v", peerConfig.GetWgPubKey(), err)
|
|
||||||
return sshconfig.PeerHostKey{}, false
|
|
||||||
}
|
|
||||||
|
|
||||||
peerIP := e.extractPeerIP(peerConfig)
|
|
||||||
hostname := e.extractHostname(peerConfig)
|
|
||||||
|
|
||||||
return sshconfig.PeerHostKey{
|
|
||||||
Hostname: hostname,
|
|
||||||
IP: peerIP,
|
|
||||||
FQDN: peerConfig.GetFqdn(),
|
|
||||||
HostKey: hostKey,
|
|
||||||
}, true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractPeerIP extracts IP address from peer's allowed IPs
|
// extractPeerIP extracts IP address from peer's allowed IPs
|
||||||
@@ -171,25 +160,6 @@ func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
|
|||||||
return ""
|
return ""
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateKnownHostsFile updates the SSH known_hosts file
|
|
||||||
func (e *Engine) updateKnownHostsFile(peerKeys []sshconfig.PeerHostKey) error {
|
|
||||||
configMgr := sshconfig.NewManager()
|
|
||||||
if err := configMgr.UpdatePeerHostKeys(peerKeys); err != nil {
|
|
||||||
return fmt.Errorf("update peer host keys: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// updateSSHClientConfig updates SSH client configuration with peer hostnames
|
|
||||||
func (e *Engine) updateSSHClientConfig(peerKeys []sshconfig.PeerHostKey) {
|
|
||||||
configMgr := sshconfig.NewManager()
|
|
||||||
if err := configMgr.SetupSSHClientConfig(peerKeys); err != nil {
|
|
||||||
log.Warnf("failed to update SSH client config with peer hostnames: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Debugf("updated SSH client config with %d peer hostnames", len(peerKeys))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
|
||||||
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
|
||||||
for _, peerConfig := range remotePeers {
|
for _, peerConfig := range remotePeers {
|
||||||
@@ -210,30 +180,51 @@ func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig)
|
|||||||
log.Debugf("updated peer SSH host keys for daemon API access")
|
log.Debugf("updated peer SSH host keys for daemon API access")
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
|
||||||
|
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
statusRecorder := e.statusRecorder
|
||||||
|
e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
|
if statusRecorder == nil {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus := statusRecorder.GetFullStatus()
|
||||||
|
for _, peerState := range fullStatus.Peers {
|
||||||
|
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||||
|
if len(peerState.SSHHostKey) > 0 {
|
||||||
|
return peerState.SSHHostKey, true
|
||||||
|
}
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
func (e *Engine) cleanupSSHConfig() {
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
configMgr := sshconfig.NewManager()
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
log.Warnf("failed to remove SSH client config: %v", err)
|
log.Warnf("failed to remove SSH client config: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("SSH client config cleanup completed")
|
log.Debugf("SSH client config cleanup completed")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := configMgr.RemoveKnownHostsFile(); err != nil {
|
|
||||||
log.Warnf("failed to remove SSH known_hosts: %v", err)
|
|
||||||
} else {
|
|
||||||
log.Debugf("SSH known_hosts cleanup completed")
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// startSSHServer initializes and starts the SSH server with proper configuration.
|
// startSSHServer initializes and starts the SSH server with proper configuration.
|
||||||
func (e *Engine) startSSHServer() error {
|
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
|
||||||
if e.wgInterface == nil {
|
if e.wgInterface == nil {
|
||||||
return errors.New("wg interface not initialized")
|
return errors.New("wg interface not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
server := sshserver.New(e.config.SSHKey)
|
serverConfig := &sshserver.Config{
|
||||||
|
HostKeyPEM: e.config.SSHKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := sshserver.New(serverConfig)
|
||||||
|
|
||||||
wgAddr := e.wgInterface.Address()
|
wgAddr := e.wgInterface.Address()
|
||||||
server.SetNetworkValidation(wgAddr)
|
server.SetNetworkValidation(wgAddr)
|
||||||
@@ -259,15 +250,10 @@ func (e *Engine) startSSHServer() error {
|
|||||||
log.Warnf("failed to setup SSH port redirection: %v", err)
|
log.Warnf("failed to setup SSH port redirection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := e.setupSSHSocketFilter(server); err != nil {
|
|
||||||
return fmt.Errorf("set socket filter: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := server.Start(e.ctx, listenAddr); err != nil {
|
if err := server.Start(e.ctx, listenAddr); err != nil {
|
||||||
return fmt.Errorf("start SSH server: %w", err)
|
return fmt.Errorf("start SSH server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -281,7 +281,15 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
networkMap = &mgmtProto.NetworkMap{
|
networkMap = &mgmtProto.NetworkMap{
|
||||||
Serial: 7,
|
Serial: 7,
|
||||||
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
|
||||||
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
|
SshConfig: &mgmtProto.SSHConfig{
|
||||||
|
SshEnabled: true,
|
||||||
|
JwtConfig: &mgmtProto.JWTConfig{
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
Audience: "test-audience",
|
||||||
|
KeysLocation: "test-keys",
|
||||||
|
MaxTokenAge: 3600,
|
||||||
|
},
|
||||||
|
}},
|
||||||
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
|
||||||
RemotePeersIsEmpty: false,
|
RemotePeersIsEmpty: false,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -128,6 +128,7 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
|
|||||||
config.EnableSSHSFTP,
|
config.EnableSSHSFTP,
|
||||||
config.EnableSSHLocalPortForwarding,
|
config.EnableSSHLocalPortForwarding,
|
||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
return serverKey, loginResp, err
|
return serverKey, loginResp, err
|
||||||
@@ -158,6 +159,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
|
|||||||
config.EnableSSHSFTP,
|
config.EnableSSHSFTP,
|
||||||
config.EnableSSHLocalPortForwarding,
|
config.EnableSSHLocalPortForwarding,
|
||||||
config.EnableSSHRemotePortForwarding,
|
config.EnableSSHRemotePortForwarding,
|
||||||
|
config.DisableSSHAuth,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -21,9 +21,9 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
"github.com/netbirdio/netbird/client/internal/relay"
|
"github.com/netbirdio/netbird/client/internal/relay"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
"github.com/netbirdio/netbird/route"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const eventQueueSize = 10
|
const eventQueueSize = 10
|
||||||
|
|||||||
@@ -54,6 +54,7 @@ type ConfigInput struct {
|
|||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
NATExternalIPs []string
|
NATExternalIPs []string
|
||||||
CustomDNSAddress []byte
|
CustomDNSAddress []byte
|
||||||
RosenpassEnabled *bool
|
RosenpassEnabled *bool
|
||||||
@@ -102,6 +103,7 @@ type Config struct {
|
|||||||
EnableSSHSFTP *bool
|
EnableSSHSFTP *bool
|
||||||
EnableSSHLocalPortForwarding *bool
|
EnableSSHLocalPortForwarding *bool
|
||||||
EnableSSHRemotePortForwarding *bool
|
EnableSSHRemotePortForwarding *bool
|
||||||
|
DisableSSHAuth *bool
|
||||||
|
|
||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
@@ -423,6 +425,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
|
||||||
|
if *input.DisableSSHAuth {
|
||||||
|
log.Infof("disabling SSH authentication")
|
||||||
|
} else {
|
||||||
|
log.Infof("enabling SSH authentication")
|
||||||
|
}
|
||||||
|
config.DisableSSHAuth = input.DisableSSHAuth
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
|
||||||
log.Infof("updating DNS route interval to %s (old value %s)",
|
log.Infof("updating DNS route interval to %s (old value %s)",
|
||||||
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())
|
||||||
|
|||||||
@@ -18,8 +18,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ConnectionListener export internal Listener for mobile
|
// ConnectionListener export internal Listener for mobile
|
||||||
|
|||||||
@@ -283,6 +283,7 @@ type LoginRequest struct {
|
|||||||
EnableSSHSFTP *bool `protobuf:"varint,34,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
|
EnableSSHSFTP *bool `protobuf:"varint,34,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
|
||||||
EnableSSHLocalPortForwarding *bool `protobuf:"varint,35,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"`
|
EnableSSHLocalPortForwarding *bool `protobuf:"varint,35,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"`
|
||||||
EnableSSHRemotePortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
EnableSSHRemotePortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||||
|
DisableSSHAuth *bool `protobuf:"varint,37,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -570,6 +571,13 @@ func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *LoginRequest) GetDisableSSHAuth() bool {
|
||||||
|
if x != nil && x.DisableSSHAuth != nil {
|
||||||
|
return *x.DisableSSHAuth
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type LoginResponse struct {
|
type LoginResponse struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
|
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
|
||||||
@@ -1100,6 +1108,7 @@ type GetConfigResponse struct {
|
|||||||
EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"`
|
EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"`
|
||||||
EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"`
|
EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"`
|
||||||
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
|
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
|
||||||
|
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -1302,6 +1311,13 @@ func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *GetConfigResponse) GetDisableSSHAuth() bool {
|
||||||
|
if x != nil {
|
||||||
|
return x.DisableSSHAuth
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
type PeerState struct {
|
type PeerState struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
@@ -3781,6 +3797,7 @@ type SetConfigRequest struct {
|
|||||||
EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
|
EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
|
||||||
EnableSSHLocalPortForward *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForward,proto3,oneof" json:"enableSSHLocalPortForward,omitempty"`
|
EnableSSHLocalPortForward *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForward,proto3,oneof" json:"enableSSHLocalPortForward,omitempty"`
|
||||||
EnableSSHRemotePortForward *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForward,proto3,oneof" json:"enableSSHRemotePortForward,omitempty"`
|
EnableSSHRemotePortForward *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForward,proto3,oneof" json:"enableSSHRemotePortForward,omitempty"`
|
||||||
|
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
sizeCache protoimpl.SizeCache
|
sizeCache protoimpl.SizeCache
|
||||||
}
|
}
|
||||||
@@ -4039,6 +4056,13 @@ func (x *SetConfigRequest) GetEnableSSHRemotePortForward() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (x *SetConfigRequest) GetDisableSSHAuth() bool {
|
||||||
|
if x != nil && x.DisableSSHAuth != nil {
|
||||||
|
return *x.DisableSSHAuth
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
type SetConfigResponse struct {
|
type SetConfigResponse struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
unknownFields protoimpl.UnknownFields
|
unknownFields protoimpl.UnknownFields
|
||||||
@@ -4774,6 +4798,262 @@ func (x *GetPeerSSHHostKeyResponse) GetFound() bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||||
|
type RequestJWTAuthRequest struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthRequest) Reset() {
|
||||||
|
*x = RequestJWTAuthRequest{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[71]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RequestJWTAuthRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[71]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{71}
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestJWTAuthResponse contains authentication flow information
|
||||||
|
type RequestJWTAuthResponse struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// verification URI for user authentication
|
||||||
|
VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"`
|
||||||
|
// complete verification URI (with embedded user code)
|
||||||
|
VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"`
|
||||||
|
// user code to enter on verification URI
|
||||||
|
UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"`
|
||||||
|
// device code for polling
|
||||||
|
DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
|
||||||
|
// expiration time in seconds
|
||||||
|
ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
|
||||||
|
// if a cached token is available, it will be returned here
|
||||||
|
CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"`
|
||||||
|
// maximum age of JWT tokens in seconds (from management server)
|
||||||
|
MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) Reset() {
|
||||||
|
*x = RequestJWTAuthResponse{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[72]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*RequestJWTAuthResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[72]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{72}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) GetVerificationURI() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.VerificationURI
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.VerificationURIComplete
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) GetUserCode() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.UserCode
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) GetDeviceCode() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DeviceCode
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) GetExpiresIn() int64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.ExpiresIn
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) GetCachedToken() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.CachedToken
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.MaxTokenAge
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitJWTTokenRequest for waiting for authentication completion
|
||||||
|
type WaitJWTTokenRequest struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// device code from RequestJWTAuthResponse
|
||||||
|
DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
|
||||||
|
// user code for verification
|
||||||
|
UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenRequest) Reset() {
|
||||||
|
*x = WaitJWTTokenRequest{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[73]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenRequest) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*WaitJWTTokenRequest) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[73]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead.
|
||||||
|
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{73}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenRequest) GetDeviceCode() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.DeviceCode
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenRequest) GetUserCode() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.UserCode
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitJWTTokenResponse contains the JWT token after authentication
|
||||||
|
type WaitJWTTokenResponse struct {
|
||||||
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
|
// JWT token (access token or ID token)
|
||||||
|
Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"`
|
||||||
|
// token type (e.g., "Bearer")
|
||||||
|
TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"`
|
||||||
|
// expiration time in seconds
|
||||||
|
ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
|
||||||
|
unknownFields protoimpl.UnknownFields
|
||||||
|
sizeCache protoimpl.SizeCache
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenResponse) Reset() {
|
||||||
|
*x = WaitJWTTokenResponse{}
|
||||||
|
mi := &file_daemon_proto_msgTypes[74]
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenResponse) String() string {
|
||||||
|
return protoimpl.X.MessageStringOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (*WaitJWTTokenResponse) ProtoMessage() {}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
|
||||||
|
mi := &file_daemon_proto_msgTypes[74]
|
||||||
|
if x != nil {
|
||||||
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
|
if ms.LoadMessageInfo() == nil {
|
||||||
|
ms.StoreMessageInfo(mi)
|
||||||
|
}
|
||||||
|
return ms
|
||||||
|
}
|
||||||
|
return mi.MessageOf(x)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead.
|
||||||
|
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) {
|
||||||
|
return file_daemon_proto_rawDescGZIP(), []int{74}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenResponse) GetToken() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.Token
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenResponse) GetTokenType() string {
|
||||||
|
if x != nil {
|
||||||
|
return x.TokenType
|
||||||
|
}
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
|
||||||
|
if x != nil {
|
||||||
|
return x.ExpiresIn
|
||||||
|
}
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
type PortInfo_Range struct {
|
type PortInfo_Range struct {
|
||||||
state protoimpl.MessageState `protogen:"open.v1"`
|
state protoimpl.MessageState `protogen:"open.v1"`
|
||||||
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
|
||||||
@@ -4784,7 +5064,7 @@ type PortInfo_Range struct {
|
|||||||
|
|
||||||
func (x *PortInfo_Range) Reset() {
|
func (x *PortInfo_Range) Reset() {
|
||||||
*x = PortInfo_Range{}
|
*x = PortInfo_Range{}
|
||||||
mi := &file_daemon_proto_msgTypes[72]
|
mi := &file_daemon_proto_msgTypes[76]
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
ms.StoreMessageInfo(mi)
|
ms.StoreMessageInfo(mi)
|
||||||
}
|
}
|
||||||
@@ -4796,7 +5076,7 @@ func (x *PortInfo_Range) String() string {
|
|||||||
func (*PortInfo_Range) ProtoMessage() {}
|
func (*PortInfo_Range) ProtoMessage() {}
|
||||||
|
|
||||||
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
|
||||||
mi := &file_daemon_proto_msgTypes[72]
|
mi := &file_daemon_proto_msgTypes[76]
|
||||||
if x != nil {
|
if x != nil {
|
||||||
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
|
||||||
if ms.LoadMessageInfo() == nil {
|
if ms.LoadMessageInfo() == nil {
|
||||||
@@ -4831,7 +5111,7 @@ var File_daemon_proto protoreflect.FileDescriptor
|
|||||||
const file_daemon_proto_rawDesc = "" +
|
const file_daemon_proto_rawDesc = "" +
|
||||||
"\n" +
|
"\n" +
|
||||||
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
|
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
|
||||||
"\fEmptyRequest\"\x94\x11\n" +
|
"\fEmptyRequest\"\xd4\x11\n" +
|
||||||
"\fLoginRequest\x12\x1a\n" +
|
"\fLoginRequest\x12\x1a\n" +
|
||||||
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
|
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
|
||||||
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
|
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
|
||||||
@@ -4872,7 +5152,8 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\renableSSHRoot\x18! \x01(\bH\x14R\renableSSHRoot\x88\x01\x01\x12)\n" +
|
"\renableSSHRoot\x18! \x01(\bH\x14R\renableSSHRoot\x88\x01\x01\x12)\n" +
|
||||||
"\renableSSHSFTP\x18\" \x01(\bH\x15R\renableSSHSFTP\x88\x01\x01\x12G\n" +
|
"\renableSSHSFTP\x18\" \x01(\bH\x15R\renableSSHSFTP\x88\x01\x01\x12G\n" +
|
||||||
"\x1cenableSSHLocalPortForwarding\x18# \x01(\bH\x16R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" +
|
"\x1cenableSSHLocalPortForwarding\x18# \x01(\bH\x16R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" +
|
||||||
"\x1denableSSHRemotePortForwarding\x18$ \x01(\bH\x17R\x1denableSSHRemotePortForwarding\x88\x01\x01B\x13\n" +
|
"\x1denableSSHRemotePortForwarding\x18$ \x01(\bH\x17R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" +
|
||||||
|
"\x0edisableSSHAuth\x18% \x01(\bH\x18R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
|
||||||
"\x11_rosenpassEnabledB\x10\n" +
|
"\x11_rosenpassEnabledB\x10\n" +
|
||||||
"\x0e_interfaceNameB\x10\n" +
|
"\x0e_interfaceNameB\x10\n" +
|
||||||
"\x0e_wireguardPortB\x17\n" +
|
"\x0e_wireguardPortB\x17\n" +
|
||||||
@@ -4896,7 +5177,8 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x0e_enableSSHRootB\x10\n" +
|
"\x0e_enableSSHRootB\x10\n" +
|
||||||
"\x0e_enableSSHSFTPB\x1f\n" +
|
"\x0e_enableSSHSFTPB\x1f\n" +
|
||||||
"\x1d_enableSSHLocalPortForwardingB \n" +
|
"\x1d_enableSSHLocalPortForwardingB \n" +
|
||||||
"\x1e_enableSSHRemotePortForwarding\"\xb5\x01\n" +
|
"\x1e_enableSSHRemotePortForwardingB\x11\n" +
|
||||||
|
"\x0f_disableSSHAuth\"\xb5\x01\n" +
|
||||||
"\rLoginResponse\x12$\n" +
|
"\rLoginResponse\x12$\n" +
|
||||||
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
|
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
|
||||||
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
|
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
|
||||||
@@ -4929,7 +5211,7 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\fDownResponse\"P\n" +
|
"\fDownResponse\"P\n" +
|
||||||
"\x10GetConfigRequest\x12 \n" +
|
"\x10GetConfigRequest\x12 \n" +
|
||||||
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
|
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
|
||||||
"\busername\x18\x02 \x01(\tR\busername\"\x8b\b\n" +
|
"\busername\x18\x02 \x01(\tR\busername\"\xb3\b\n" +
|
||||||
"\x11GetConfigResponse\x12$\n" +
|
"\x11GetConfigResponse\x12$\n" +
|
||||||
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
|
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
|
||||||
"\n" +
|
"\n" +
|
||||||
@@ -4958,7 +5240,8 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" +
|
"\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" +
|
||||||
"\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" +
|
"\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" +
|
||||||
"\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
|
"\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
|
||||||
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\"\xfe\x05\n" +
|
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
|
||||||
|
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\"\xfe\x05\n" +
|
||||||
"\tPeerState\x12\x0e\n" +
|
"\tPeerState\x12\x0e\n" +
|
||||||
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
|
||||||
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
|
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
|
||||||
@@ -5161,7 +5444,7 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
|
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
|
||||||
"\f_profileNameB\v\n" +
|
"\f_profileNameB\v\n" +
|
||||||
"\t_username\"\x17\n" +
|
"\t_username\"\x17\n" +
|
||||||
"\x15SwitchProfileResponse\"\xcd\x0f\n" +
|
"\x15SwitchProfileResponse\"\x8d\x10\n" +
|
||||||
"\x10SetConfigRequest\x12\x1a\n" +
|
"\x10SetConfigRequest\x12\x1a\n" +
|
||||||
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
|
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
|
||||||
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
|
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
|
||||||
@@ -5198,7 +5481,8 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" +
|
"\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" +
|
||||||
"\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12A\n" +
|
"\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12A\n" +
|
||||||
"\x19enableSSHLocalPortForward\x18\x1f \x01(\bH\x14R\x19enableSSHLocalPortForward\x88\x01\x01\x12C\n" +
|
"\x19enableSSHLocalPortForward\x18\x1f \x01(\bH\x14R\x19enableSSHLocalPortForward\x88\x01\x01\x12C\n" +
|
||||||
"\x1aenableSSHRemotePortForward\x18 \x01(\bH\x15R\x1aenableSSHRemotePortForward\x88\x01\x01B\x13\n" +
|
"\x1aenableSSHRemotePortForward\x18 \x01(\bH\x15R\x1aenableSSHRemotePortForward\x88\x01\x01\x12+\n" +
|
||||||
|
"\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
|
||||||
"\x11_rosenpassEnabledB\x10\n" +
|
"\x11_rosenpassEnabledB\x10\n" +
|
||||||
"\x0e_interfaceNameB\x10\n" +
|
"\x0e_interfaceNameB\x10\n" +
|
||||||
"\x0e_wireguardPortB\x17\n" +
|
"\x0e_wireguardPortB\x17\n" +
|
||||||
@@ -5220,7 +5504,8 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x0e_enableSSHRootB\x10\n" +
|
"\x0e_enableSSHRootB\x10\n" +
|
||||||
"\x0e_enableSSHSFTPB\x1c\n" +
|
"\x0e_enableSSHSFTPB\x1c\n" +
|
||||||
"\x1a_enableSSHLocalPortForwardB\x1d\n" +
|
"\x1a_enableSSHLocalPortForwardB\x1d\n" +
|
||||||
"\x1b_enableSSHRemotePortForward\"\x13\n" +
|
"\x1b_enableSSHRemotePortForwardB\x11\n" +
|
||||||
|
"\x0f_disableSSHAuth\"\x13\n" +
|
||||||
"\x11SetConfigResponse\"Q\n" +
|
"\x11SetConfigResponse\"Q\n" +
|
||||||
"\x11AddProfileRequest\x12\x1a\n" +
|
"\x11AddProfileRequest\x12\x1a\n" +
|
||||||
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
|
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
|
||||||
@@ -5259,7 +5544,27 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"sshHostKey\x12\x16\n" +
|
"sshHostKey\x12\x16\n" +
|
||||||
"\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" +
|
"\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" +
|
||||||
"\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" +
|
"\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" +
|
||||||
"\x05found\x18\x04 \x01(\bR\x05found*b\n" +
|
"\x05found\x18\x04 \x01(\bR\x05found\"\x17\n" +
|
||||||
|
"\x15RequestJWTAuthRequest\"\x9a\x02\n" +
|
||||||
|
"\x16RequestJWTAuthResponse\x12(\n" +
|
||||||
|
"\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" +
|
||||||
|
"\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" +
|
||||||
|
"\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" +
|
||||||
|
"\n" +
|
||||||
|
"deviceCode\x18\x04 \x01(\tR\n" +
|
||||||
|
"deviceCode\x12\x1c\n" +
|
||||||
|
"\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" +
|
||||||
|
"\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" +
|
||||||
|
"\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" +
|
||||||
|
"\x13WaitJWTTokenRequest\x12\x1e\n" +
|
||||||
|
"\n" +
|
||||||
|
"deviceCode\x18\x01 \x01(\tR\n" +
|
||||||
|
"deviceCode\x12\x1a\n" +
|
||||||
|
"\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" +
|
||||||
|
"\x14WaitJWTTokenResponse\x12\x14\n" +
|
||||||
|
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
|
||||||
|
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
|
||||||
|
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
|
||||||
"\bLogLevel\x12\v\n" +
|
"\bLogLevel\x12\v\n" +
|
||||||
"\aUNKNOWN\x10\x00\x12\t\n" +
|
"\aUNKNOWN\x10\x00\x12\t\n" +
|
||||||
"\x05PANIC\x10\x01\x12\t\n" +
|
"\x05PANIC\x10\x01\x12\t\n" +
|
||||||
@@ -5268,7 +5573,7 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x04WARN\x10\x04\x12\b\n" +
|
"\x04WARN\x10\x04\x12\b\n" +
|
||||||
"\x04INFO\x10\x05\x12\t\n" +
|
"\x04INFO\x10\x05\x12\t\n" +
|
||||||
"\x05DEBUG\x10\x06\x12\t\n" +
|
"\x05DEBUG\x10\x06\x12\t\n" +
|
||||||
"\x05TRACE\x10\a2\xeb\x10\n" +
|
"\x05TRACE\x10\a2\x8b\x12\n" +
|
||||||
"\rDaemonService\x126\n" +
|
"\rDaemonService\x126\n" +
|
||||||
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
|
||||||
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
|
||||||
@@ -5301,7 +5606,9 @@ const file_daemon_proto_rawDesc = "" +
|
|||||||
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
|
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
|
||||||
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
|
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
|
||||||
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
|
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
|
||||||
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00B\bZ\x06/protob\x06proto3"
|
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
|
||||||
|
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
|
||||||
|
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
file_daemon_proto_rawDescOnce sync.Once
|
file_daemon_proto_rawDescOnce sync.Once
|
||||||
@@ -5316,7 +5623,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
|
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
|
||||||
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 74)
|
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 78)
|
||||||
var file_daemon_proto_goTypes = []any{
|
var file_daemon_proto_goTypes = []any{
|
||||||
(LogLevel)(0), // 0: daemon.LogLevel
|
(LogLevel)(0), // 0: daemon.LogLevel
|
||||||
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
|
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
|
||||||
@@ -5392,18 +5699,22 @@ var file_daemon_proto_goTypes = []any{
|
|||||||
(*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
|
(*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
|
||||||
(*GetPeerSSHHostKeyRequest)(nil), // 72: daemon.GetPeerSSHHostKeyRequest
|
(*GetPeerSSHHostKeyRequest)(nil), // 72: daemon.GetPeerSSHHostKeyRequest
|
||||||
(*GetPeerSSHHostKeyResponse)(nil), // 73: daemon.GetPeerSSHHostKeyResponse
|
(*GetPeerSSHHostKeyResponse)(nil), // 73: daemon.GetPeerSSHHostKeyResponse
|
||||||
nil, // 74: daemon.Network.ResolvedIPsEntry
|
(*RequestJWTAuthRequest)(nil), // 74: daemon.RequestJWTAuthRequest
|
||||||
(*PortInfo_Range)(nil), // 75: daemon.PortInfo.Range
|
(*RequestJWTAuthResponse)(nil), // 75: daemon.RequestJWTAuthResponse
|
||||||
nil, // 76: daemon.SystemEvent.MetadataEntry
|
(*WaitJWTTokenRequest)(nil), // 76: daemon.WaitJWTTokenRequest
|
||||||
(*durationpb.Duration)(nil), // 77: google.protobuf.Duration
|
(*WaitJWTTokenResponse)(nil), // 77: daemon.WaitJWTTokenResponse
|
||||||
(*timestamppb.Timestamp)(nil), // 78: google.protobuf.Timestamp
|
nil, // 78: daemon.Network.ResolvedIPsEntry
|
||||||
|
(*PortInfo_Range)(nil), // 79: daemon.PortInfo.Range
|
||||||
|
nil, // 80: daemon.SystemEvent.MetadataEntry
|
||||||
|
(*durationpb.Duration)(nil), // 81: google.protobuf.Duration
|
||||||
|
(*timestamppb.Timestamp)(nil), // 82: google.protobuf.Timestamp
|
||||||
}
|
}
|
||||||
var file_daemon_proto_depIdxs = []int32{
|
var file_daemon_proto_depIdxs = []int32{
|
||||||
77, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
81, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
|
||||||
78, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
82, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
|
||||||
78, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
82, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
|
||||||
77, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
81, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
|
||||||
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
|
||||||
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
|
||||||
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
|
||||||
@@ -5412,8 +5723,8 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
|
||||||
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
|
||||||
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
|
||||||
74, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
78, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
|
||||||
75, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
79, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
|
||||||
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
|
||||||
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
|
||||||
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
|
||||||
@@ -5424,10 +5735,10 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
|
||||||
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
|
||||||
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
|
||||||
78, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
82, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
|
||||||
76, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
80, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
|
||||||
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
|
||||||
77, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
81, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
|
||||||
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
|
||||||
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
|
||||||
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
|
||||||
@@ -5459,37 +5770,41 @@ var file_daemon_proto_depIdxs = []int32{
|
|||||||
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
|
||||||
70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
|
||||||
72, // 59: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
72, // 59: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
|
||||||
5, // 60: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
74, // 60: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
|
||||||
7, // 61: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
76, // 61: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
|
||||||
9, // 62: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
5, // 62: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
|
||||||
11, // 63: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
7, // 63: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
|
||||||
13, // 64: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
9, // 64: daemon.DaemonService.Up:output_type -> daemon.UpResponse
|
||||||
15, // 65: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
11, // 65: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
|
||||||
24, // 66: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
13, // 66: daemon.DaemonService.Down:output_type -> daemon.DownResponse
|
||||||
26, // 67: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
15, // 67: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
|
||||||
26, // 68: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
24, // 68: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
|
||||||
31, // 69: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
26, // 69: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
33, // 70: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
26, // 70: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
|
||||||
35, // 71: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
31, // 71: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
|
||||||
37, // 72: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
33, // 72: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
|
||||||
40, // 73: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
35, // 73: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
|
||||||
42, // 74: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
37, // 74: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
|
||||||
44, // 75: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
40, // 75: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
|
||||||
46, // 76: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
42, // 76: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
|
||||||
50, // 77: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
44, // 77: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
|
||||||
52, // 78: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
46, // 78: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
|
||||||
54, // 79: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
50, // 79: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
|
||||||
56, // 80: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
52, // 80: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
|
||||||
58, // 81: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
54, // 81: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
|
||||||
60, // 82: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
56, // 82: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
|
||||||
62, // 83: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
58, // 83: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
|
||||||
64, // 84: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
60, // 84: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
|
||||||
67, // 85: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
62, // 85: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
|
||||||
69, // 86: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
64, // 86: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
|
||||||
71, // 87: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
67, // 87: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
|
||||||
73, // 88: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
69, // 88: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
|
||||||
60, // [60:89] is the sub-list for method output_type
|
71, // 89: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
|
||||||
31, // [31:60] is the sub-list for method input_type
|
73, // 90: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
|
||||||
|
75, // 91: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
|
||||||
|
77, // 92: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
|
||||||
|
62, // [62:93] is the sub-list for method output_type
|
||||||
|
31, // [31:62] is the sub-list for method input_type
|
||||||
31, // [31:31] is the sub-list for extension type_name
|
31, // [31:31] is the sub-list for extension type_name
|
||||||
31, // [31:31] is the sub-list for extension extendee
|
31, // [31:31] is the sub-list for extension extendee
|
||||||
0, // [0:31] is the sub-list for field type_name
|
0, // [0:31] is the sub-list for field type_name
|
||||||
@@ -5518,7 +5833,7 @@ func file_daemon_proto_init() {
|
|||||||
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
|
||||||
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
|
||||||
NumEnums: 3,
|
NumEnums: 3,
|
||||||
NumMessages: 74,
|
NumMessages: 78,
|
||||||
NumExtensions: 0,
|
NumExtensions: 0,
|
||||||
NumServices: 1,
|
NumServices: 1,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -87,6 +87,12 @@ service DaemonService {
|
|||||||
|
|
||||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||||
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
|
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
|
||||||
|
|
||||||
|
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||||
|
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
|
||||||
|
|
||||||
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
|
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
@@ -166,6 +172,7 @@ message LoginRequest {
|
|||||||
optional bool enableSSHSFTP = 34;
|
optional bool enableSSHSFTP = 34;
|
||||||
optional bool enableSSHLocalPortForwarding = 35;
|
optional bool enableSSHLocalPortForwarding = 35;
|
||||||
optional bool enableSSHRemotePortForwarding = 36;
|
optional bool enableSSHRemotePortForwarding = 36;
|
||||||
|
optional bool disableSSHAuth = 37;
|
||||||
}
|
}
|
||||||
|
|
||||||
message LoginResponse {
|
message LoginResponse {
|
||||||
@@ -268,6 +275,8 @@ message GetConfigResponse {
|
|||||||
bool enableSSHLocalPortForwarding = 22;
|
bool enableSSHLocalPortForwarding = 22;
|
||||||
|
|
||||||
bool enableSSHRemotePortForwarding = 23;
|
bool enableSSHRemotePortForwarding = 23;
|
||||||
|
|
||||||
|
bool disableSSHAuth = 25;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerState contains the latest state of a peer
|
// PeerState contains the latest state of a peer
|
||||||
@@ -612,6 +621,7 @@ message SetConfigRequest {
|
|||||||
optional bool enableSSHSFTP = 30;
|
optional bool enableSSHSFTP = 30;
|
||||||
optional bool enableSSHLocalPortForward = 31;
|
optional bool enableSSHLocalPortForward = 31;
|
||||||
optional bool enableSSHRemotePortForward = 32;
|
optional bool enableSSHRemotePortForward = 32;
|
||||||
|
optional bool disableSSHAuth = 33;
|
||||||
}
|
}
|
||||||
|
|
||||||
message SetConfigResponse{}
|
message SetConfigResponse{}
|
||||||
@@ -681,3 +691,43 @@ message GetPeerSSHHostKeyResponse {
|
|||||||
// indicates if the SSH host key was found
|
// indicates if the SSH host key was found
|
||||||
bool found = 4;
|
bool found = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// RequestJWTAuthRequest for initiating JWT authentication flow
|
||||||
|
message RequestJWTAuthRequest {
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestJWTAuthResponse contains authentication flow information
|
||||||
|
message RequestJWTAuthResponse {
|
||||||
|
// verification URI for user authentication
|
||||||
|
string verificationURI = 1;
|
||||||
|
// complete verification URI (with embedded user code)
|
||||||
|
string verificationURIComplete = 2;
|
||||||
|
// user code to enter on verification URI
|
||||||
|
string userCode = 3;
|
||||||
|
// device code for polling
|
||||||
|
string deviceCode = 4;
|
||||||
|
// expiration time in seconds
|
||||||
|
int64 expiresIn = 5;
|
||||||
|
// if a cached token is available, it will be returned here
|
||||||
|
string cachedToken = 6;
|
||||||
|
// maximum age of JWT tokens in seconds (from management server)
|
||||||
|
int64 maxTokenAge = 7;
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitJWTTokenRequest for waiting for authentication completion
|
||||||
|
message WaitJWTTokenRequest {
|
||||||
|
// device code from RequestJWTAuthResponse
|
||||||
|
string deviceCode = 1;
|
||||||
|
// user code for verification
|
||||||
|
string userCode = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitJWTTokenResponse contains the JWT token after authentication
|
||||||
|
message WaitJWTTokenResponse {
|
||||||
|
// JWT token (access token or ID token)
|
||||||
|
string token = 1;
|
||||||
|
// token type (e.g., "Bearer")
|
||||||
|
string tokenType = 2;
|
||||||
|
// expiration time in seconds
|
||||||
|
int64 expiresIn = 3;
|
||||||
|
}
|
||||||
|
|||||||
@@ -66,6 +66,10 @@ type DaemonServiceClient interface {
|
|||||||
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
|
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
|
||||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||||
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
|
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
|
||||||
|
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||||
|
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
|
||||||
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
|
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
|
||||||
}
|
}
|
||||||
|
|
||||||
type daemonServiceClient struct {
|
type daemonServiceClient struct {
|
||||||
@@ -360,6 +364,24 @@ func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeer
|
|||||||
return out, nil
|
return out, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
|
||||||
|
out := new(RequestJWTAuthResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
|
||||||
|
out := new(WaitJWTTokenResponse)
|
||||||
|
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
return out, nil
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonServiceServer is the server API for DaemonService service.
|
// DaemonServiceServer is the server API for DaemonService service.
|
||||||
// All implementations must embed UnimplementedDaemonServiceServer
|
// All implementations must embed UnimplementedDaemonServiceServer
|
||||||
// for forward compatibility
|
// for forward compatibility
|
||||||
@@ -412,6 +434,10 @@ type DaemonServiceServer interface {
|
|||||||
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
|
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
|
||||||
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
|
||||||
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
|
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
|
||||||
|
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||||
|
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
|
||||||
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
|
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
|
||||||
mustEmbedUnimplementedDaemonServiceServer()
|
mustEmbedUnimplementedDaemonServiceServer()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -506,6 +532,12 @@ func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeature
|
|||||||
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
|
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
|
||||||
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
|
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
|
||||||
}
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
|
||||||
|
}
|
||||||
|
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
|
||||||
|
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
|
||||||
|
}
|
||||||
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
|
||||||
|
|
||||||
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
|
||||||
@@ -1044,6 +1076,42 @@ func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Conte
|
|||||||
return interceptor(ctx, in, info, handler)
|
return interceptor(ctx, in, info, handler)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(RequestJWTAuthRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
|
||||||
|
in := new(WaitJWTTokenRequest)
|
||||||
|
if err := dec(in); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
if interceptor == nil {
|
||||||
|
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
|
||||||
|
}
|
||||||
|
info := &grpc.UnaryServerInfo{
|
||||||
|
Server: srv,
|
||||||
|
FullMethod: "/daemon.DaemonService/WaitJWTToken",
|
||||||
|
}
|
||||||
|
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
|
||||||
|
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
|
||||||
|
}
|
||||||
|
return interceptor(ctx, in, info, handler)
|
||||||
|
}
|
||||||
|
|
||||||
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
|
||||||
// It's only intended for direct use with grpc.RegisterService,
|
// It's only intended for direct use with grpc.RegisterService,
|
||||||
// and not to be introspected or modified (even as a copy)
|
// and not to be introspected or modified (even as a copy)
|
||||||
@@ -1163,6 +1231,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
|
|||||||
MethodName: "GetPeerSSHHostKey",
|
MethodName: "GetPeerSSHHostKey",
|
||||||
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
|
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
MethodName: "RequestJWTAuth",
|
||||||
|
Handler: _DaemonService_RequestJWTAuth_Handler,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
MethodName: "WaitJWTToken",
|
||||||
|
Handler: _DaemonService_WaitJWTToken_Handler,
|
||||||
|
},
|
||||||
},
|
},
|
||||||
Streams: []grpc.StreamDesc{
|
Streams: []grpc.StreamDesc{
|
||||||
{
|
{
|
||||||
|
|||||||
73
client/server/jwt_cache.go
Normal file
73
client/server/jwt_cache.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/awnumar/memguard"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
type jwtCache struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
enclave *memguard.Enclave
|
||||||
|
expiresAt time.Time
|
||||||
|
timer *time.Timer
|
||||||
|
maxTokenSize int
|
||||||
|
}
|
||||||
|
|
||||||
|
func newJWTCache() *jwtCache {
|
||||||
|
return &jwtCache{
|
||||||
|
maxTokenSize: 8192,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *jwtCache) store(token string, maxAge time.Duration) {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
c.cleanup()
|
||||||
|
|
||||||
|
if c.timer != nil {
|
||||||
|
c.timer.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenBytes := []byte(token)
|
||||||
|
c.enclave = memguard.NewEnclave(tokenBytes)
|
||||||
|
|
||||||
|
c.expiresAt = time.Now().Add(maxAge)
|
||||||
|
|
||||||
|
c.timer = time.AfterFunc(maxAge, func() {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.cleanup()
|
||||||
|
c.timer = nil
|
||||||
|
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *jwtCache) get() (string, bool) {
|
||||||
|
c.mu.RLock()
|
||||||
|
defer c.mu.RUnlock()
|
||||||
|
|
||||||
|
if c.enclave == nil || time.Now().After(c.expiresAt) {
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
|
||||||
|
buffer, err := c.enclave.Open()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("Failed to open JWT token enclave: %v", err)
|
||||||
|
return "", false
|
||||||
|
}
|
||||||
|
defer buffer.Destroy()
|
||||||
|
|
||||||
|
token := string(buffer.Bytes())
|
||||||
|
return token, true
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup destroys the secure enclave, must be called with lock held
|
||||||
|
func (c *jwtCache) cleanup() {
|
||||||
|
if c.enclave != nil {
|
||||||
|
c.enclave = nil
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -11,8 +11,8 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type selectRoute struct {
|
type selectRoute struct {
|
||||||
|
|||||||
@@ -46,6 +46,9 @@ const (
|
|||||||
defaultMaxRetryTime = 14 * 24 * time.Hour
|
defaultMaxRetryTime = 14 * 24 * time.Hour
|
||||||
defaultRetryMultiplier = 1.7
|
defaultRetryMultiplier = 1.7
|
||||||
|
|
||||||
|
// JWT token cache TTL for the client daemon
|
||||||
|
defaultJWTCacheTTL = 5 * time.Minute
|
||||||
|
|
||||||
errRestoreResidualState = "failed to restore residual state: %v"
|
errRestoreResidualState = "failed to restore residual state: %v"
|
||||||
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
|
||||||
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
|
||||||
@@ -81,6 +84,8 @@ type Server struct {
|
|||||||
profileManager *profilemanager.ServiceManager
|
profileManager *profilemanager.ServiceManager
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
|
|
||||||
|
jwtCache *jwtCache
|
||||||
}
|
}
|
||||||
|
|
||||||
type oauthAuthFlow struct {
|
type oauthAuthFlow struct {
|
||||||
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
|||||||
profileManager: profilemanager.NewServiceManager(configFile),
|
profileManager: profilemanager.NewServiceManager(configFile),
|
||||||
profilesDisabled: profilesDisabled,
|
profilesDisabled: profilesDisabled,
|
||||||
updateSettingsDisabled: updateSettingsDisabled,
|
updateSettingsDisabled: updateSettingsDisabled,
|
||||||
|
jwtCache: newJWTCache(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -370,6 +376,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
|
|||||||
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
config.EnableSSHSFTP = msg.EnableSSHSFTP
|
||||||
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
|
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
|
||||||
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
|
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
|
||||||
|
if msg.DisableSSHAuth != nil {
|
||||||
|
config.DisableSSHAuth = msg.DisableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
if msg.Mtu != nil {
|
if msg.Mtu != nil {
|
||||||
mtu := uint16(*msg.Mtu)
|
mtu := uint16(*msg.Mtu)
|
||||||
@@ -486,7 +495,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
|
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
|
||||||
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
|
||||||
log.Debugf("using previous oauth flow info")
|
log.Debugf("using previous oauth flow info")
|
||||||
return &proto.LoginResponse{
|
return &proto.LoginResponse{
|
||||||
@@ -503,7 +512,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
|
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("getting a request OAuth flow failed: %v", err)
|
log.Errorf("getting a request OAuth flow failed: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1077,28 +1086,41 @@ func (s *Server) GetPeerSSHHostKey(
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.mutex.Lock()
|
s.mutex.Lock()
|
||||||
defer s.mutex.Unlock()
|
connectClient := s.connectClient
|
||||||
|
statusRecorder := s.statusRecorder
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
response := &proto.GetPeerSSHHostKeyResponse{
|
if connectClient == nil {
|
||||||
Found: false,
|
return nil, errors.New("client not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.statusRecorder == nil {
|
engine := connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, errors.New("engine not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
peerAddress := req.GetPeerAddress()
|
||||||
|
hostKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||||
|
|
||||||
|
response := &proto.GetPeerSSHHostKeyResponse{
|
||||||
|
Found: found,
|
||||||
|
}
|
||||||
|
|
||||||
|
if !found {
|
||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
fullStatus := s.statusRecorder.GetFullStatus()
|
response.SshHostKey = hostKey
|
||||||
peerAddress := req.GetPeerAddress()
|
|
||||||
|
|
||||||
// Search for peer by IP or FQDN
|
if statusRecorder == nil {
|
||||||
|
return response, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
fullStatus := statusRecorder.GetFullStatus()
|
||||||
for _, peerState := range fullStatus.Peers {
|
for _, peerState := range fullStatus.Peers {
|
||||||
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
|
||||||
if len(peerState.SSHHostKey) > 0 {
|
response.PeerIP = peerState.IP
|
||||||
response.SshHostKey = peerState.SSHHostKey
|
response.PeerFQDN = peerState.FQDN
|
||||||
response.PeerIP = peerState.IP
|
|
||||||
response.PeerFQDN = peerState.FQDN
|
|
||||||
response.Found = true
|
|
||||||
}
|
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -1106,6 +1128,137 @@ func (s *Server) GetPeerSSHHostKey(
|
|||||||
return response, nil
|
return response, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
|
||||||
|
// NB_SSH_JWT_CACHE_TTL=0 disables caching
|
||||||
|
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
|
||||||
|
func getJWTCacheTTL() time.Duration {
|
||||||
|
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
|
||||||
|
if envValue == "" {
|
||||||
|
return defaultJWTCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
seconds, err := strconv.Atoi(envValue)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
|
||||||
|
return defaultJWTCacheTTL
|
||||||
|
}
|
||||||
|
|
||||||
|
if seconds == 0 {
|
||||||
|
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
ttl := time.Duration(seconds) * time.Second
|
||||||
|
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
|
||||||
|
return ttl
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestJWTAuth initiates JWT authentication flow for SSH
|
||||||
|
func (s *Server) RequestJWTAuth(
|
||||||
|
ctx context.Context,
|
||||||
|
_ *proto.RequestJWTAuthRequest,
|
||||||
|
) (*proto.RequestJWTAuthResponse, error) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mutex.Lock()
|
||||||
|
config := s.config
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtCacheTTL := getJWTCacheTTL()
|
||||||
|
if jwtCacheTTL > 0 {
|
||||||
|
if cachedToken, found := s.jwtCache.get(); found {
|
||||||
|
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
|
||||||
|
|
||||||
|
return &proto.RequestJWTAuthResponse{
|
||||||
|
CachedToken: cachedToken,
|
||||||
|
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
isDesktop := isUnixRunningDesktop()
|
||||||
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
|
||||||
|
if err != nil {
|
||||||
|
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
|
||||||
|
if err != nil {
|
||||||
|
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mutex.Lock()
|
||||||
|
s.oauthAuthFlow.flow = oAuthFlow
|
||||||
|
s.oauthAuthFlow.info = authInfo
|
||||||
|
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
return &proto.RequestJWTAuthResponse{
|
||||||
|
VerificationURI: authInfo.VerificationURI,
|
||||||
|
VerificationURIComplete: authInfo.VerificationURIComplete,
|
||||||
|
UserCode: authInfo.UserCode,
|
||||||
|
DeviceCode: authInfo.DeviceCode,
|
||||||
|
ExpiresIn: int64(authInfo.ExpiresIn),
|
||||||
|
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// WaitJWTToken waits for JWT authentication completion
|
||||||
|
func (s *Server) WaitJWTToken(
|
||||||
|
ctx context.Context,
|
||||||
|
req *proto.WaitJWTTokenRequest,
|
||||||
|
) (*proto.WaitJWTTokenResponse, error) {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return nil, ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mutex.Lock()
|
||||||
|
oAuthFlow := s.oauthAuthFlow.flow
|
||||||
|
authInfo := s.oauthAuthFlow.info
|
||||||
|
s.mutex.Unlock()
|
||||||
|
|
||||||
|
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
|
||||||
|
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
|
||||||
|
if err != nil {
|
||||||
|
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
token := tokenInfo.GetTokenToUse()
|
||||||
|
|
||||||
|
jwtCacheTTL := getJWTCacheTTL()
|
||||||
|
if jwtCacheTTL > 0 {
|
||||||
|
s.jwtCache.store(token, jwtCacheTTL)
|
||||||
|
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
|
||||||
|
} else {
|
||||||
|
log.Debug("JWT caching disabled, not storing token")
|
||||||
|
}
|
||||||
|
|
||||||
|
s.mutex.Lock()
|
||||||
|
s.oauthAuthFlow = oauthAuthFlow{}
|
||||||
|
s.mutex.Unlock()
|
||||||
|
return &proto.WaitJWTTokenResponse{
|
||||||
|
Token: tokenInfo.GetTokenToUse(),
|
||||||
|
TokenType: tokenInfo.TokenType,
|
||||||
|
ExpiresIn: int64(tokenInfo.ExpiresIn),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isUnixRunningDesktop() bool {
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
|
}
|
||||||
|
|
||||||
func (s *Server) runProbes() {
|
func (s *Server) runProbes() {
|
||||||
if s.connectClient == nil {
|
if s.connectClient == nil {
|
||||||
return
|
return
|
||||||
@@ -1191,13 +1344,18 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
|
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
|
disableSSHAuth := false
|
||||||
|
if s.config.DisableSSHAuth != nil {
|
||||||
|
disableSSHAuth = *s.config.DisableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.GetConfigResponse{
|
return &proto.GetConfigResponse{
|
||||||
ManagementUrl: managementURL.String(),
|
ManagementUrl: managementURL.String(),
|
||||||
PreSharedKey: preSharedKey,
|
PreSharedKey: preSharedKey,
|
||||||
AdminURL: adminURL.String(),
|
AdminURL: adminURL.String(),
|
||||||
InterfaceName: cfg.WgIface,
|
InterfaceName: cfg.WgIface,
|
||||||
WireguardPort: int64(cfg.WgPort),
|
WireguardPort: int64(cfg.WgPort),
|
||||||
Mtu: int64(cfg.MTU),
|
Mtu: int64(cfg.MTU),
|
||||||
DisableAutoConnect: cfg.DisableAutoConnect,
|
DisableAutoConnect: cfg.DisableAutoConnect,
|
||||||
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
ServerSSHAllowed: *cfg.ServerSSHAllowed,
|
||||||
RosenpassEnabled: cfg.RosenpassEnabled,
|
RosenpassEnabled: cfg.RosenpassEnabled,
|
||||||
@@ -1214,6 +1372,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
|
|||||||
EnableSSHSFTP: enableSSHSFTP,
|
EnableSSHSFTP: enableSSHSFTP,
|
||||||
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
|
||||||
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
|
||||||
|
DisableSSHAuth: disableSSHAuth,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,9 +6,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func registerStates(mgr *statemanager.Manager) {
|
func registerStates(mgr *statemanager.Manager) {
|
||||||
mgr.RegisterState(&dns.ShutdownState{})
|
mgr.RegisterState(&dns.ShutdownState{})
|
||||||
mgr.RegisterState(&systemops.ShutdownState{})
|
mgr.RegisterState(&systemops.ShutdownState{})
|
||||||
|
mgr.RegisterState(&config.ShutdownState{})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/config"
|
||||||
)
|
)
|
||||||
|
|
||||||
func registerStates(mgr *statemanager.Manager) {
|
func registerStates(mgr *statemanager.Manager) {
|
||||||
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
|
|||||||
mgr.RegisterState(&systemops.ShutdownState{})
|
mgr.RegisterState(&systemops.ShutdownState{})
|
||||||
mgr.RegisterState(&nftables.ShutdownState{})
|
mgr.RegisterState(&nftables.ShutdownState{})
|
||||||
mgr.RegisterState(&iptables.ShutdownState{})
|
mgr.RegisterState(&iptables.ShutdownState{})
|
||||||
|
mgr.RegisterState(&config.ShutdownState{})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,6 +21,8 @@ import (
|
|||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -40,12 +42,10 @@ type Client struct {
|
|||||||
windowsStdinMode uint32 // nolint:unused
|
windowsStdinMode uint32 // nolint:unused
|
||||||
}
|
}
|
||||||
|
|
||||||
// Close terminates the SSH connection
|
|
||||||
func (c *Client) Close() error {
|
func (c *Client) Close() error {
|
||||||
return c.client.Close()
|
return c.client.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
// OpenTerminal opens an interactive terminal session
|
|
||||||
func (c *Client) OpenTerminal(ctx context.Context) error {
|
func (c *Client) OpenTerminal(ctx context.Context) error {
|
||||||
session, err := c.client.NewSession()
|
session, err := c.client.NewSession()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -259,43 +259,29 @@ func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error
|
|||||||
return session, cleanup, nil
|
return session, cleanup, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Dial connects to the given ssh server with proper host key verification
|
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
|
||||||
func Dial(ctx context.Context, addr, user string) (*Client, error) {
|
func getDefaultDaemonAddr() string {
|
||||||
hostKeyCallback, err := createHostKeyCallback(addr)
|
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
|
||||||
if err != nil {
|
return addr
|
||||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
|
||||||
}
|
}
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
config := &ssh.ClientConfig{
|
return DefaultDaemonAddrWindows
|
||||||
User: user,
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
HostKeyCallback: hostKeyCallback,
|
|
||||||
}
|
}
|
||||||
|
return DefaultDaemonAddr
|
||||||
return dial(ctx, "tcp", addr, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DialInsecure connects to the given ssh server without host key verification (for testing only)
|
|
||||||
func DialInsecure(ctx context.Context, addr, user string) (*Client, error) {
|
|
||||||
config := &ssh.ClientConfig{
|
|
||||||
User: user,
|
|
||||||
Timeout: 30 * time.Second,
|
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 - Only used for tests
|
|
||||||
}
|
|
||||||
|
|
||||||
return dial(ctx, "tcp", addr, config)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOptions contains options for SSH connections
|
// DialOptions contains options for SSH connections
|
||||||
type DialOptions struct {
|
type DialOptions struct {
|
||||||
KnownHostsFile string
|
KnownHostsFile string
|
||||||
IdentityFile string
|
IdentityFile string
|
||||||
DaemonAddr string
|
DaemonAddr string
|
||||||
|
SkipCachedToken bool
|
||||||
|
InsecureSkipVerify bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialWithOptions connects to the given ssh server with specified options
|
// Dial connects to the given ssh server with specified options
|
||||||
func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
|
||||||
hostKeyCallback, err := createHostKeyCallbackWithOptions(addr, opts)
|
hostKeyCallback, err := createHostKeyCallback(opts)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create host key callback: %w", err)
|
return nil, fmt.Errorf("create host key callback: %w", err)
|
||||||
}
|
}
|
||||||
@@ -306,7 +292,6 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
|
|||||||
HostKeyCallback: hostKeyCallback,
|
HostKeyCallback: hostKeyCallback,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add SSH key authentication if identity file is specified
|
|
||||||
if opts.IdentityFile != "" {
|
if opts.IdentityFile != "" {
|
||||||
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -315,11 +300,16 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
|
|||||||
config.Auth = append(config.Auth, authMethod)
|
config.Auth = append(config.Auth, authMethod)
|
||||||
}
|
}
|
||||||
|
|
||||||
return dial(ctx, "tcp", addr, config)
|
daemonAddr := opts.DaemonAddr
|
||||||
|
if daemonAddr == "" {
|
||||||
|
daemonAddr = getDefaultDaemonAddr()
|
||||||
|
}
|
||||||
|
|
||||||
|
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
|
||||||
}
|
}
|
||||||
|
|
||||||
// dial establishes an SSH connection
|
// dialSSH establishes an SSH connection without JWT authentication
|
||||||
func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
|
||||||
dialer := &net.Dialer{}
|
dialer := &net.Dialer{}
|
||||||
conn, err := dialer.DialContext(ctx, network, addr)
|
conn, err := dialer.DialContext(ctx, network, addr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -340,143 +330,84 @@ func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (
|
|||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files
|
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
|
||||||
func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) {
|
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
|
||||||
daemonAddr := os.Getenv("NB_DAEMON_ADDR")
|
host, portStr, err := net.SplitHostPort(addr)
|
||||||
if daemonAddr == "" {
|
if err != nil {
|
||||||
if runtime.GOOS == "windows" {
|
return nil, fmt.Errorf("parse address %s: %w", addr, err)
|
||||||
daemonAddr = DefaultDaemonAddrWindows
|
|
||||||
} else {
|
|
||||||
daemonAddr = DefaultDaemonAddr
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
return createHostKeyCallbackWithDaemonAddr(addr, daemonAddr)
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||||
|
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("SSH server detection failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serverType.RequiresJWT() {
|
||||||
|
return dialSSH(ctx, network, addr, config)
|
||||||
|
}
|
||||||
|
|
||||||
|
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("request JWT token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
|
||||||
|
return dialSSH(ctx, network, addr, configWithJWT)
|
||||||
}
|
}
|
||||||
|
|
||||||
// createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address
|
// requestJWTToken requests a JWT token from the NetBird daemon
|
||||||
func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCallback, error) {
|
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
|
||||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
conn, err := connectToDaemon(daemonAddr)
|
||||||
// First try to get host key from NetBird daemon
|
if err != nil {
|
||||||
if err := verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr); err == nil {
|
return "", fmt.Errorf("connect to daemon: %w", err)
|
||||||
return nil
|
}
|
||||||
}
|
defer conn.Close()
|
||||||
|
|
||||||
// Fallback to known_hosts files
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
knownHostsFiles := getKnownHostsFiles()
|
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
|
||||||
|
|
||||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
|
||||||
|
|
||||||
for _, file := range knownHostsFiles {
|
|
||||||
if callback, err := knownhosts.New(file); err == nil {
|
|
||||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try each known_hosts callback
|
|
||||||
for _, callback := range hostKeyCallbacks {
|
|
||||||
if err := callback(hostname, remote, key); err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||||
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||||
client, err := connectToDaemon(daemonAddr)
|
conn, err := connectToDaemon(daemonAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := client.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
log.Debugf("daemon connection close error: %v", err)
|
log.Debugf("daemon connection close error: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
addresses := buildAddressList(hostname, remote)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
|
verifier := nbssh.NewDaemonHostKeyVerifier(client)
|
||||||
|
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||||
return verifyKeyWithDaemon(client, addresses, key)
|
return callback(hostname, remote, key)
|
||||||
}
|
}
|
||||||
|
|
||||||
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
addr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
conn, err := grpc.NewClient(
|
||||||
ctx,
|
addr,
|
||||||
strings.TrimPrefix(daemonAddr, "tcp://"),
|
|
||||||
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
grpc.WithTransportCredentials(insecure.NewCredentials()),
|
||||||
grpc.WithBlock(),
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err)
|
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
|
||||||
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
|
||||||
addresses := []string{hostname}
|
|
||||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
|
||||||
if host != hostname {
|
|
||||||
addresses = append(addresses, host)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return addresses
|
|
||||||
}
|
|
||||||
|
|
||||||
func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error {
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
|
|
||||||
for _, addr := range addresses {
|
|
||||||
if err := checkAddressKey(client, addr, key); err == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
|
|
||||||
}
|
|
||||||
|
|
||||||
func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
|
||||||
PeerAddress: addr,
|
|
||||||
})
|
|
||||||
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("daemon query error for %s: %v", addr, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if !response.GetFound() {
|
|
||||||
log.Debugf("SSH host key not found in daemon for address: %s", addr)
|
|
||||||
return fmt.Errorf("key not found")
|
|
||||||
}
|
|
||||||
|
|
||||||
return compareKeys(response.GetSshHostKey(), key, addr)
|
|
||||||
}
|
|
||||||
|
|
||||||
func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error {
|
|
||||||
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) {
|
|
||||||
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type())
|
|
||||||
return fmt.Errorf("key mismatch")
|
|
||||||
}
|
|
||||||
|
|
||||||
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
||||||
func getKnownHostsFiles() []string {
|
func getKnownHostsFiles() []string {
|
||||||
var files []string
|
var files []string
|
||||||
@@ -503,8 +434,12 @@ func getKnownHostsFiles() []string {
|
|||||||
return files
|
return files
|
||||||
}
|
}
|
||||||
|
|
||||||
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options
|
// createHostKeyCallback creates a host key verification callback
|
||||||
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) {
|
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||||
|
if opts.InsecureSkipVerify {
|
||||||
|
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
|
||||||
|
}
|
||||||
|
|
||||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||||
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -7,29 +7,36 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
sshserver "github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestMain handles package-level setup and cleanup
|
// TestMain handles package-level setup and cleanup
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||||
|
// This happens when running tests as non-privileged user with fallback
|
||||||
|
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||||
|
// Just exit with error to break the recursion
|
||||||
|
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
// Run tests
|
// Run tests
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
|
|
||||||
// Cleanup any created test users
|
// Cleanup any created test users
|
||||||
cleanupTestUsers()
|
testutil.CleanupTestUsers()
|
||||||
|
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
@@ -39,19 +46,14 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
|||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Generate client key pair
|
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create and start server
|
// Create and start server
|
||||||
server := sshserver.New(hostKey)
|
serverConfig := &sshserver.Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := sshserver.New(serverConfig)
|
||||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||||
|
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverAddr := sshserver.StartTestServer(t, server)
|
serverAddr := sshserver.StartTestServer(t, server)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := server.Stop()
|
err := server.Stop()
|
||||||
@@ -62,8 +64,10 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
currentUser := getCurrentUsername()
|
currentUser := testutil.GetTestUsername(t)
|
||||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := client.Close()
|
err := client.Close()
|
||||||
@@ -75,7 +79,7 @@ func TestSSHClient_DialWithKey(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestSSHClient_CommandExecution(t *testing.T) {
|
func TestSSHClient_CommandExecution(t *testing.T) {
|
||||||
if runtime.GOOS == "windows" && isCI() {
|
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||||
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -129,20 +133,16 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Generate client key for multiple connections
|
// Generate client key for multiple connections
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = server.AddAuthorizedKey("multi-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
const numClients = 3
|
const numClients = 3
|
||||||
clients := make([]*Client, numClients)
|
clients := make([]*Client, numClients)
|
||||||
|
|
||||||
for i := 0; i < numClients; i++ {
|
for i := 0; i < numClients; i++ {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
currentUser := getCurrentUsername()
|
currentUser := testutil.GetTestUsername(t)
|
||||||
client, err := DialInsecure(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i))
|
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
cancel()
|
cancel()
|
||||||
require.NoError(t, err, "Client %d should connect successfully", i)
|
require.NoError(t, err, "Client %d should connect successfully", i)
|
||||||
clients[i] = client
|
clients[i] = client
|
||||||
@@ -161,19 +161,14 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
err = server.AddAuthorizedKey("cancel-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
t.Run("connection with short timeout", func(t *testing.T) {
|
t.Run("connection with short timeout", func(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
currentUser := getCurrentUsername()
|
currentUser := testutil.GetTestUsername(t)
|
||||||
_, err = DialInsecure(ctx, serverAddr, currentUser)
|
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// Check for actual timeout-related errors rather than string matching
|
// Check for actual timeout-related errors rather than string matching
|
||||||
assert.True(t,
|
assert.True(t,
|
||||||
@@ -187,8 +182,10 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
|
|||||||
t.Run("command execution cancellation", func(t *testing.T) {
|
t.Run("command execution cancellation", func(t *testing.T) {
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
currentUser := getCurrentUsername()
|
currentUser := testutil.GetTestUsername(t)
|
||||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := client.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
@@ -214,7 +211,11 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
|
|||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server := sshserver.New(hostKey)
|
serverConfig := &sshserver.Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := sshserver.New(serverConfig)
|
||||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||||
|
|
||||||
serverAddr := sshserver.StartTestServer(t, server)
|
serverAddr := sshserver.StartTestServer(t, server)
|
||||||
@@ -226,10 +227,12 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
|
|||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
currentUser := getCurrentUsername()
|
currentUser := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
|
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
|
||||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
if client != nil {
|
if client != nil {
|
||||||
require.NoError(t, client.Close(), "Client should close without error")
|
require.NoError(t, client.Close(), "Client should close without error")
|
||||||
@@ -282,24 +285,22 @@ func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Clie
|
|||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
serverConfig := &sshserver.Config{
|
||||||
require.NoError(t, err)
|
HostKeyPEM: hostKey,
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := sshserver.New(serverConfig)
|
||||||
server := sshserver.New(hostKey)
|
|
||||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||||
|
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverAddr := sshserver.StartTestServer(t, server)
|
serverAddr := sshserver.StartTestServer(t, server)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
currentUser := getCurrentUsername()
|
currentUser := testutil.GetTestUsername(t)
|
||||||
client, err := DialInsecure(ctx, serverAddr, currentUser)
|
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
return server, serverAddr, client
|
return server, serverAddr, client
|
||||||
@@ -361,18 +362,14 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
|||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
serverConfig := &sshserver.Config{
|
||||||
require.NoError(t, err)
|
HostKeyPEM: hostKey,
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := sshserver.New(serverConfig)
|
||||||
server := sshserver.New(hostKey)
|
|
||||||
server.SetAllowLocalPortForwarding(true)
|
server.SetAllowLocalPortForwarding(true)
|
||||||
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
server.SetAllowRootLogin(true) // Allow root/admin login for tests
|
||||||
|
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
serverAddr := sshserver.StartTestServer(t, server)
|
serverAddr := sshserver.StartTestServer(t, server)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := server.Stop()
|
err := server.Stop()
|
||||||
@@ -387,11 +384,13 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Skip if running as system account that can't do port forwarding
|
// Skip if running as system account that can't do port forwarding
|
||||||
if isSystemAccount(realUser) {
|
if testutil.IsSystemAccount(realUser) {
|
||||||
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
|
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
|
||||||
}
|
}
|
||||||
|
|
||||||
client, err := DialInsecure(ctx, serverAddr, realUser)
|
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
|
||||||
|
InsecureSkipVerify: true, // Skip host key verification for test
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := client.Close(); err != nil {
|
if err := client.Close(); err != nil {
|
||||||
@@ -478,180 +477,6 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
|
|||||||
assert.Equal(t, expectedResponse, string(response))
|
assert.Equal(t, expectedResponse, string(response))
|
||||||
}
|
}
|
||||||
|
|
||||||
// getCurrentUsername returns the current username for SSH connections
|
|
||||||
func getCurrentUsername() string {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
if currentUser, err := user.Current(); err == nil {
|
|
||||||
// Check if this is a system account that can't authenticate
|
|
||||||
if isSystemAccount(currentUser.Username) {
|
|
||||||
// In CI environments, create a test user; otherwise try Administrator
|
|
||||||
if isCI() {
|
|
||||||
if testUser := getOrCreateTestUser(); testUser != "" {
|
|
||||||
return testUser
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try Administrator first for local development
|
|
||||||
if _, err := user.Lookup("Administrator"); err == nil {
|
|
||||||
return "Administrator"
|
|
||||||
}
|
|
||||||
if testUser := getOrCreateTestUser(); testUser != "" {
|
|
||||||
return testUser
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
// On Windows, return the full domain\username for proper authentication
|
|
||||||
return currentUser.Username
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if username := os.Getenv("USER"); username != "" {
|
|
||||||
return username
|
|
||||||
}
|
|
||||||
|
|
||||||
if currentUser, err := user.Current(); err == nil {
|
|
||||||
return currentUser.Username
|
|
||||||
}
|
|
||||||
|
|
||||||
return "test-user"
|
|
||||||
}
|
|
||||||
|
|
||||||
// isCI checks if we're running in GitHub Actions CI
|
|
||||||
func isCI() bool {
|
|
||||||
// Check standard CI environment variables
|
|
||||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
|
|
||||||
hostname, err := os.Hostname()
|
|
||||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// getOrCreateTestUser creates a test user on Windows if needed
|
|
||||||
func getOrCreateTestUser() string {
|
|
||||||
testUsername := "netbird-test-user"
|
|
||||||
|
|
||||||
// Check if user already exists
|
|
||||||
if _, err := user.Lookup(testUsername); err == nil {
|
|
||||||
return testUsername
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to create the user using PowerShell
|
|
||||||
if createWindowsTestUser(testUsername) {
|
|
||||||
// Register cleanup for the test user
|
|
||||||
registerTestUserCleanup(testUsername)
|
|
||||||
return testUsername
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
var createdTestUsers = make(map[string]bool)
|
|
||||||
var testUsersToCleanup []string
|
|
||||||
|
|
||||||
// registerTestUserCleanup registers a test user for cleanup
|
|
||||||
func registerTestUserCleanup(username string) {
|
|
||||||
if !createdTestUsers[username] {
|
|
||||||
createdTestUsers[username] = true
|
|
||||||
testUsersToCleanup = append(testUsersToCleanup, username)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupTestUsers removes all created test users
|
|
||||||
func cleanupTestUsers() {
|
|
||||||
for _, username := range testUsersToCleanup {
|
|
||||||
removeWindowsTestUser(username)
|
|
||||||
}
|
|
||||||
testUsersToCleanup = nil
|
|
||||||
createdTestUsers = make(map[string]bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeWindowsTestUser removes a local user on Windows using PowerShell
|
|
||||||
func removeWindowsTestUser(username string) {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// PowerShell command to remove a local user
|
|
||||||
psCmd := fmt.Sprintf(`
|
|
||||||
try {
|
|
||||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
|
||||||
Write-Output "User removed successfully"
|
|
||||||
} catch {
|
|
||||||
if ($_.Exception.Message -like "*cannot be found*") {
|
|
||||||
Write-Output "User not found (already removed)"
|
|
||||||
} else {
|
|
||||||
Write-Error $_.Exception.Message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`, username)
|
|
||||||
|
|
||||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
|
||||||
} else {
|
|
||||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// createWindowsTestUser creates a local user on Windows using PowerShell
|
|
||||||
func createWindowsTestUser(username string) bool {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// PowerShell command to create a local user
|
|
||||||
psCmd := fmt.Sprintf(`
|
|
||||||
try {
|
|
||||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
|
||||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
|
||||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
|
||||||
Write-Output "User created successfully"
|
|
||||||
} catch {
|
|
||||||
if ($_.Exception.Message -like "*already exists*") {
|
|
||||||
Write-Output "User already exists"
|
|
||||||
} else {
|
|
||||||
Write-Error $_.Exception.Message
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`, username, username)
|
|
||||||
|
|
||||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to create test user: %v, output: %s", err, string(output))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Test user creation result: %s", string(output))
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// isSystemAccount checks if the user is a system account that can't authenticate
|
|
||||||
func isSystemAccount(username string) bool {
|
|
||||||
systemAccounts := []string{
|
|
||||||
"system",
|
|
||||||
"NT AUTHORITY\\SYSTEM",
|
|
||||||
"NT AUTHORITY\\LOCAL SERVICE",
|
|
||||||
"NT AUTHORITY\\NETWORK SERVICE",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sysAccount := range systemAccounts {
|
|
||||||
if strings.EqualFold(username, sysAccount) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
|
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
|
||||||
func getRealCurrentUser() (string, error) {
|
func getRealCurrentUser() (string, error) {
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
|
|||||||
167
client/ssh/common.go
Normal file
167
client/ssh/common.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
package ssh
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
NetBirdSSHConfigFile = "99-netbird.conf"
|
||||||
|
|
||||||
|
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
|
||||||
|
WindowsSSHConfigDir = "ssh/ssh_config.d"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
// ErrPeerNotFound indicates the peer was not found in the network
|
||||||
|
ErrPeerNotFound = errors.New("peer not found in network")
|
||||||
|
// ErrNoStoredKey indicates the peer has no stored SSH host key
|
||||||
|
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
|
||||||
|
)
|
||||||
|
|
||||||
|
// HostKeyVerifier provides SSH host key verification
|
||||||
|
type HostKeyVerifier interface {
|
||||||
|
VerifySSHHostKey(peerAddress string, key []byte) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
|
||||||
|
type DaemonHostKeyVerifier struct {
|
||||||
|
client proto.DaemonServiceClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
|
||||||
|
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
|
||||||
|
return &DaemonHostKeyVerifier{
|
||||||
|
client: client,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
|
||||||
|
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||||
|
PeerAddress: peerAddress,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
if !response.GetFound() {
|
||||||
|
return ErrPeerNotFound
|
||||||
|
}
|
||||||
|
|
||||||
|
storedKeyData := response.GetSshHostKey()
|
||||||
|
|
||||||
|
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||||
|
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
|
||||||
|
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("request JWT auth: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if useCache && authResponse.CachedToken != "" {
|
||||||
|
log.Debug("Using cached authentication token")
|
||||||
|
return authResponse.CachedToken, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if stderr != nil {
|
||||||
|
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||||
|
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
||||||
|
if authResponse.UserCode != "" {
|
||||||
|
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||||
|
}
|
||||||
|
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||||
|
}
|
||||||
|
|
||||||
|
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||||
|
DeviceCode: authResponse.DeviceCode,
|
||||||
|
UserCode: authResponse.UserCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("wait for JWT token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if stdout != nil {
|
||||||
|
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
|
||||||
|
}
|
||||||
|
return tokenResponse.Token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// VerifyHostKey verifies an SSH host key against stored peer key data.
|
||||||
|
// Returns nil only if the presented key matches the stored key.
|
||||||
|
// Returns ErrNoStoredKey if storedKeyData is empty.
|
||||||
|
// Returns an error if the keys don't match or if parsing fails.
|
||||||
|
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
|
||||||
|
if len(storedKeyData) == 0 {
|
||||||
|
return ErrNoStoredKey
|
||||||
|
}
|
||||||
|
|
||||||
|
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
|
||||||
|
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddJWTAuth prepends JWT password authentication to existing auth methods.
|
||||||
|
// This ensures JWT auth is tried first while preserving any existing auth methods.
|
||||||
|
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
|
||||||
|
configWithJWT := *config
|
||||||
|
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
|
||||||
|
return &configWithJWT
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
|
||||||
|
// It tries multiple addresses (hostname, IP) for the peer before failing.
|
||||||
|
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
|
||||||
|
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||||
|
addresses := buildAddressList(hostname, remote)
|
||||||
|
presentedKey := key.Marshal()
|
||||||
|
|
||||||
|
for _, addr := range addresses {
|
||||||
|
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
|
||||||
|
if errors.Is(err, ErrPeerNotFound) {
|
||||||
|
// Try other addresses for this peer
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// Verified
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// buildAddressList creates a list of addresses to check for host key verification.
|
||||||
|
// It includes the original hostname and extracts the host part from the remote address if different.
|
||||||
|
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||||
|
addresses := []string{hostname}
|
||||||
|
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||||
|
if host != hostname {
|
||||||
|
addresses = append(addresses, host)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return addresses
|
||||||
|
}
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package config
|
package config
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bufio"
|
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
@@ -12,50 +11,41 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// EnvDisableSSHConfig is the environment variable to disable SSH config management
|
|
||||||
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
|
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
|
||||||
|
|
||||||
// EnvForceSSHConfig is the environment variable to force SSH config generation even with many peers
|
|
||||||
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
|
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
|
||||||
|
|
||||||
// MaxPeersForSSHConfig is the default maximum number of peers before SSH config generation is disabled
|
|
||||||
MaxPeersForSSHConfig = 200
|
MaxPeersForSSHConfig = 200
|
||||||
|
|
||||||
// fileWriteTimeout is the timeout for file write operations
|
|
||||||
fileWriteTimeout = 2 * time.Second
|
fileWriteTimeout = 2 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
// isSSHConfigDisabled checks if SSH config management is disabled via environment variable
|
|
||||||
func isSSHConfigDisabled() bool {
|
func isSSHConfigDisabled() bool {
|
||||||
value := os.Getenv(EnvDisableSSHConfig)
|
value := os.Getenv(EnvDisableSSHConfig)
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse as boolean, default to true if non-empty but invalid
|
|
||||||
disabled, err := strconv.ParseBool(value)
|
disabled, err := strconv.ParseBool(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If not a valid boolean, treat any non-empty value as true
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return disabled
|
return disabled
|
||||||
}
|
}
|
||||||
|
|
||||||
// isSSHConfigForced checks if SSH config generation is forced via environment variable
|
|
||||||
func isSSHConfigForced() bool {
|
func isSSHConfigForced() bool {
|
||||||
value := os.Getenv(EnvForceSSHConfig)
|
value := os.Getenv(EnvForceSSHConfig)
|
||||||
if value == "" {
|
if value == "" {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Parse as boolean, default to true if non-empty but invalid
|
|
||||||
forced, err := strconv.ParseBool(value)
|
forced, err := strconv.ParseBool(value)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// If not a valid boolean, treat any non-empty value as true
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
return forced
|
return forced
|
||||||
@@ -92,85 +82,55 @@ func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// writeFileOperationWithTimeout performs a file operation with timeout
|
|
||||||
func writeFileOperationWithTimeout(filename string, operation func() error) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
done := make(chan error, 1)
|
|
||||||
go func() {
|
|
||||||
done <- operation()
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case err := <-done:
|
|
||||||
return err
|
|
||||||
case <-ctx.Done():
|
|
||||||
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Manager handles SSH client configuration for NetBird peers
|
// Manager handles SSH client configuration for NetBird peers
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
sshConfigDir string
|
sshConfigDir string
|
||||||
sshConfigFile string
|
sshConfigFile string
|
||||||
knownHostsDir string
|
|
||||||
knownHostsFile string
|
|
||||||
userKnownHosts string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerHostKey represents a peer's SSH host key information
|
// PeerSSHInfo represents a peer's SSH configuration information
|
||||||
type PeerHostKey struct {
|
type PeerSSHInfo struct {
|
||||||
Hostname string
|
Hostname string
|
||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
HostKey ssh.PublicKey
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewManager creates a new SSH config manager
|
// New creates a new SSH config manager
|
||||||
func NewManager() *Manager {
|
func New() *Manager {
|
||||||
sshConfigDir, knownHostsDir := getSystemSSHPaths()
|
sshConfigDir := getSystemSSHConfigDir()
|
||||||
return &Manager{
|
return &Manager{
|
||||||
sshConfigDir: sshConfigDir,
|
sshConfigDir: sshConfigDir,
|
||||||
sshConfigFile: "99-netbird.conf",
|
sshConfigFile: nbssh.NetBirdSSHConfigFile,
|
||||||
knownHostsDir: knownHostsDir,
|
|
||||||
knownHostsFile: "99-netbird",
|
|
||||||
userKnownHosts: "known_hosts_netbird",
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// getSystemSSHPaths returns platform-specific SSH configuration paths
|
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
|
||||||
func getSystemSSHPaths() (configDir, knownHostsDir string) {
|
func getSystemSSHConfigDir() string {
|
||||||
switch runtime.GOOS {
|
if runtime.GOOS == "windows" {
|
||||||
case "windows":
|
return getWindowsSSHConfigDir()
|
||||||
configDir, knownHostsDir = getWindowsSSHPaths()
|
|
||||||
default:
|
|
||||||
// Unix-like systems (Linux, macOS, etc.)
|
|
||||||
configDir = "/etc/ssh/ssh_config.d"
|
|
||||||
knownHostsDir = "/etc/ssh/ssh_known_hosts.d"
|
|
||||||
}
|
}
|
||||||
return configDir, knownHostsDir
|
return nbssh.UnixSSHConfigDir
|
||||||
}
|
}
|
||||||
|
|
||||||
func getWindowsSSHPaths() (configDir, knownHostsDir string) {
|
func getWindowsSSHConfigDir() string {
|
||||||
programData := os.Getenv("PROGRAMDATA")
|
programData := os.Getenv("PROGRAMDATA")
|
||||||
if programData == "" {
|
if programData == "" {
|
||||||
programData = `C:\ProgramData`
|
programData = `C:\ProgramData`
|
||||||
}
|
}
|
||||||
configDir = filepath.Join(programData, "ssh", "ssh_config.d")
|
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
|
||||||
knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d")
|
|
||||||
return configDir, knownHostsDir
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
|
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
|
||||||
func (m *Manager) SetupSSHClientConfig(peerKeys []PeerHostKey) error {
|
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
|
||||||
if !shouldGenerateSSHConfig(len(peerKeys)) {
|
if !shouldGenerateSSHConfig(len(peers)) {
|
||||||
m.logSkipReason(len(peerKeys))
|
m.logSkipReason(len(peers))
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
knownHostsPath := m.getKnownHostsPath()
|
sshConfig, err := m.buildSSHConfig(peers)
|
||||||
sshConfig := m.buildSSHConfig(peerKeys, knownHostsPath)
|
if err != nil {
|
||||||
|
return fmt.Errorf("build SSH config: %w", err)
|
||||||
|
}
|
||||||
return m.writeSSHConfig(sshConfig)
|
return m.writeSSHConfig(sshConfig)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -183,21 +143,24 @@ func (m *Manager) logSkipReason(peerCount int) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) getKnownHostsPath() string {
|
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
|
||||||
knownHostsPath, err := m.setupKnownHostsFile()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("Failed to setup known_hosts file: %v", err)
|
|
||||||
return "/dev/null"
|
|
||||||
}
|
|
||||||
return knownHostsPath
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) buildSSHConfig(peerKeys []PeerHostKey, knownHostsPath string) string {
|
|
||||||
sshConfig := m.buildConfigHeader()
|
sshConfig := m.buildConfigHeader()
|
||||||
for _, peer := range peerKeys {
|
|
||||||
sshConfig += m.buildPeerConfig(peer, knownHostsPath)
|
var allHostPatterns []string
|
||||||
|
for _, peer := range peers {
|
||||||
|
hostPatterns := m.buildHostPatterns(peer)
|
||||||
|
allHostPatterns = append(allHostPatterns, hostPatterns...)
|
||||||
}
|
}
|
||||||
return sshConfig
|
|
||||||
|
if len(allHostPatterns) > 0 {
|
||||||
|
peerConfig, err := m.buildPeerConfig(allHostPatterns)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
sshConfig += peerConfig
|
||||||
|
}
|
||||||
|
|
||||||
|
return sshConfig, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) buildConfigHeader() string {
|
func (m *Manager) buildConfigHeader() string {
|
||||||
@@ -209,25 +172,49 @@ func (m *Manager) buildConfigHeader() string {
|
|||||||
"#\n\n"
|
"#\n\n"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) buildPeerConfig(peer PeerHostKey, knownHostsPath string) string {
|
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||||
hostPatterns := m.buildHostPatterns(peer)
|
uniquePatterns := make(map[string]bool)
|
||||||
if len(hostPatterns) == 0 {
|
var deduplicatedPatterns []string
|
||||||
return ""
|
for _, pattern := range allHostPatterns {
|
||||||
|
if !uniquePatterns[pattern] {
|
||||||
|
uniquePatterns[pattern] = true
|
||||||
|
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
hostLine := strings.Join(hostPatterns, " ")
|
execPath, err := m.getNetBirdExecutablePath()
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||||
config += " # NetBird peer-specific configuration\n"
|
|
||||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
if runtime.GOOS == "windows" {
|
||||||
config += " PasswordAuthentication yes\n"
|
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||||
config += " PubkeyAuthentication yes\n"
|
} else {
|
||||||
config += " BatchMode no\n"
|
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||||
config += m.buildHostKeyConfig(knownHostsPath)
|
}
|
||||||
config += " LogLevel ERROR\n\n"
|
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||||
return config
|
config += " PasswordAuthentication yes\n"
|
||||||
|
config += " PubkeyAuthentication yes\n"
|
||||||
|
config += " BatchMode no\n"
|
||||||
|
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||||
|
config += " StrictHostKeyChecking no\n"
|
||||||
|
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
config += " UserKnownHostsFile NUL\n"
|
||||||
|
} else {
|
||||||
|
config += " UserKnownHostsFile /dev/null\n"
|
||||||
|
}
|
||||||
|
|
||||||
|
config += " CheckHostIP no\n"
|
||||||
|
config += " LogLevel ERROR\n\n"
|
||||||
|
|
||||||
|
return config, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
|
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||||
var hostPatterns []string
|
var hostPatterns []string
|
||||||
if peer.IP != "" {
|
if peer.IP != "" {
|
||||||
hostPatterns = append(hostPatterns, peer.IP)
|
hostPatterns = append(hostPatterns, peer.IP)
|
||||||
@@ -241,280 +228,55 @@ func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
|
|||||||
return hostPatterns
|
return hostPatterns
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) buildHostKeyConfig(knownHostsPath string) string {
|
|
||||||
if knownHostsPath == "/dev/null" {
|
|
||||||
return " StrictHostKeyChecking no\n" +
|
|
||||||
" UserKnownHostsFile /dev/null\n"
|
|
||||||
}
|
|
||||||
return " StrictHostKeyChecking yes\n" +
|
|
||||||
fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||||
|
|
||||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||||
log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err)
|
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||||
return m.setupUserConfig(sshConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||||
log.Warnf("Failed to write SSH config file %s: %v", sshConfigPath, err)
|
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||||
return m.setupUserConfig(sshConfig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// setupUserConfig creates SSH config in user's directory as fallback
|
|
||||||
func (m *Manager) setupUserConfig(sshConfig string) error {
|
|
||||||
homeDir, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("get user home directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
userSSHDir := filepath.Join(homeDir, ".ssh")
|
|
||||||
userConfigPath := filepath.Join(userSSHDir, "config")
|
|
||||||
|
|
||||||
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
|
|
||||||
return fmt.Errorf("create user SSH directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if NetBird config already exists in user config
|
|
||||||
exists, err := m.configExists(userConfigPath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("check existing config: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if exists {
|
|
||||||
log.Debugf("NetBird SSH config already exists in %s", userConfigPath)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Append NetBird config to user's SSH config with timeout
|
|
||||||
if err := writeFileOperationWithTimeout(userConfigPath, func() error {
|
|
||||||
file, err := os.OpenFile(userConfigPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open user SSH config: %w", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := file.Close(); err != nil {
|
|
||||||
log.Debugf("user SSH config file close error: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
if _, err := fmt.Fprintf(file, "\n%s", sshConfig); err != nil {
|
|
||||||
return fmt.Errorf("write to user SSH config: %w", err)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Added NetBird SSH config to user config: %s", userConfigPath)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// configExists checks if NetBird SSH config already exists
|
|
||||||
func (m *Manager) configExists(configPath string) (bool, error) {
|
|
||||||
file, err := os.Open(configPath)
|
|
||||||
if err != nil {
|
|
||||||
if os.IsNotExist(err) {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
return false, fmt.Errorf("open SSH config file: %w", err)
|
|
||||||
}
|
|
||||||
defer file.Close()
|
|
||||||
|
|
||||||
scanner := bufio.NewScanner(file)
|
|
||||||
for scanner.Scan() {
|
|
||||||
line := strings.TrimSpace(scanner.Text())
|
|
||||||
if strings.Contains(line, "NetBird SSH client configuration") {
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return false, scanner.Err()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveSSHClientConfig removes NetBird SSH configuration
|
// RemoveSSHClientConfig removes NetBird SSH configuration
|
||||||
func (m *Manager) RemoveSSHClientConfig() error {
|
func (m *Manager) RemoveSSHClientConfig() error {
|
||||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||||
|
err := os.Remove(sshConfigPath)
|
||||||
// Remove system-wide config if it exists
|
if err != nil && !os.IsNotExist(err) {
|
||||||
if err := os.Remove(sshConfigPath); err != nil && !os.IsNotExist(err) {
|
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
|
||||||
log.Warnf("Failed to remove system SSH config %s: %v", sshConfigPath, err)
|
}
|
||||||
} else if err == nil {
|
if err == nil {
|
||||||
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
|
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Also try to clean up user config
|
|
||||||
homeDir, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get user home directory: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
|
|
||||||
if err := m.removeFromUserConfig(userConfigPath); err != nil {
|
|
||||||
log.Warnf("Failed to clean user SSH config: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeFromUserConfig removes NetBird section from user's SSH config
|
func (m *Manager) getNetBirdExecutablePath() (string, error) {
|
||||||
func (m *Manager) removeFromUserConfig(configPath string) error {
|
execPath, err := os.Executable()
|
||||||
// This is complex to implement safely, so for now just log
|
|
||||||
// In practice, the system-wide config takes precedence anyway
|
|
||||||
log.Debugf("NetBird SSH config cleanup from user config not implemented")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupKnownHostsFile creates and returns the path to NetBird known_hosts file
|
|
||||||
func (m *Manager) setupKnownHostsFile() (string, error) {
|
|
||||||
// Try system-wide known_hosts first
|
|
||||||
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
|
|
||||||
if err := os.MkdirAll(m.knownHostsDir, 0755); err == nil {
|
|
||||||
// Create empty file if it doesn't exist
|
|
||||||
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
|
|
||||||
if err := writeFileWithTimeout(knownHostsPath, []byte("# NetBird SSH known hosts\n"), 0644); err == nil {
|
|
||||||
log.Debugf("Created NetBird known_hosts file: %s", knownHostsPath)
|
|
||||||
return knownHostsPath, nil
|
|
||||||
}
|
|
||||||
} else if err == nil {
|
|
||||||
return knownHostsPath, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Fallback to user directory
|
|
||||||
homeDir, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("get user home directory: %w", err)
|
return "", fmt.Errorf("retrieve executable path: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
userSSHDir := filepath.Join(homeDir, ".ssh")
|
realPath, err := filepath.EvalSymlinks(execPath)
|
||||||
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
|
|
||||||
return "", fmt.Errorf("create user SSH directory: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
userKnownHostsPath := filepath.Join(userSSHDir, m.userKnownHosts)
|
|
||||||
if _, err := os.Stat(userKnownHostsPath); os.IsNotExist(err) {
|
|
||||||
if err := writeFileWithTimeout(userKnownHostsPath, []byte("# NetBird SSH known hosts\n"), 0600); err != nil {
|
|
||||||
return "", fmt.Errorf("create user known_hosts file: %w", err)
|
|
||||||
}
|
|
||||||
log.Debugf("Created NetBird user known_hosts file: %s", userKnownHostsPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
return userKnownHostsPath, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// UpdatePeerHostKeys updates the known_hosts file with peer host keys
|
|
||||||
func (m *Manager) UpdatePeerHostKeys(peerKeys []PeerHostKey) error {
|
|
||||||
peerCount := len(peerKeys)
|
|
||||||
|
|
||||||
// Check if SSH config should be generated
|
|
||||||
if !shouldGenerateSSHConfig(peerCount) {
|
|
||||||
if isSSHConfigDisabled() {
|
|
||||||
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
|
|
||||||
} else {
|
|
||||||
log.Infof("SSH known_hosts update skipped: too many peers (%d > %d). Use %s=true to force.",
|
|
||||||
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
knownHostsPath, err := m.setupKnownHostsFile()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setup known_hosts file: %w", err)
|
log.Debugf("symlink resolution failed: %v", err)
|
||||||
|
return execPath, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Create updated known_hosts content - NetBird file should only contain NetBird entries
|
return realPath, nil
|
||||||
var updatedContent strings.Builder
|
|
||||||
updatedContent.WriteString("# NetBird SSH known hosts\n")
|
|
||||||
updatedContent.WriteString("# Generated automatically - do not edit manually\n\n")
|
|
||||||
|
|
||||||
// Add new NetBird entries - one entry per peer with all hostnames
|
|
||||||
for _, peerKey := range peerKeys {
|
|
||||||
entry := m.formatKnownHostsEntry(peerKey)
|
|
||||||
updatedContent.WriteString(entry)
|
|
||||||
updatedContent.WriteString("\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
// Write updated content
|
|
||||||
if err := writeFileWithTimeout(knownHostsPath, []byte(updatedContent.String()), 0644); err != nil {
|
|
||||||
return fmt.Errorf("write known_hosts file: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("Updated NetBird known_hosts with %d peer keys: %s", len(peerKeys), knownHostsPath)
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// formatKnownHostsEntry formats a peer host key as a known_hosts entry
|
// GetSSHConfigDir returns the SSH config directory path
|
||||||
func (m *Manager) formatKnownHostsEntry(peerKey PeerHostKey) string {
|
func (m *Manager) GetSSHConfigDir() string {
|
||||||
hostnames := m.getHostnameVariants(peerKey)
|
return m.sshConfigDir
|
||||||
hostnameList := strings.Join(hostnames, ",")
|
|
||||||
keyString := string(ssh.MarshalAuthorizedKey(peerKey.HostKey))
|
|
||||||
keyString = strings.TrimSpace(keyString)
|
|
||||||
return fmt.Sprintf("%s %s", hostnameList, keyString)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getHostnameVariants returns all possible hostname variants for a peer
|
// GetSSHConfigFile returns the SSH config file name
|
||||||
func (m *Manager) getHostnameVariants(peerKey PeerHostKey) []string {
|
func (m *Manager) GetSSHConfigFile() string {
|
||||||
var hostnames []string
|
return m.sshConfigFile
|
||||||
|
|
||||||
// Add IP address
|
|
||||||
if peerKey.IP != "" {
|
|
||||||
hostnames = append(hostnames, peerKey.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add FQDN
|
|
||||||
if peerKey.FQDN != "" {
|
|
||||||
hostnames = append(hostnames, peerKey.FQDN)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add hostname if different from FQDN
|
|
||||||
if peerKey.Hostname != "" && peerKey.Hostname != peerKey.FQDN {
|
|
||||||
hostnames = append(hostnames, peerKey.Hostname)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Add bracketed IP for non-standard ports (SSH standard)
|
|
||||||
if peerKey.IP != "" {
|
|
||||||
hostnames = append(hostnames, fmt.Sprintf("[%s]:22", peerKey.IP))
|
|
||||||
hostnames = append(hostnames, fmt.Sprintf("[%s]:22022", peerKey.IP))
|
|
||||||
}
|
|
||||||
|
|
||||||
return hostnames
|
|
||||||
}
|
|
||||||
|
|
||||||
// GetKnownHostsPath returns the path to the NetBird known_hosts file
|
|
||||||
func (m *Manager) GetKnownHostsPath() (string, error) {
|
|
||||||
return m.setupKnownHostsFile()
|
|
||||||
}
|
|
||||||
|
|
||||||
// RemoveKnownHostsFile removes the NetBird known_hosts file
|
|
||||||
func (m *Manager) RemoveKnownHostsFile() error {
|
|
||||||
// Remove system-wide known_hosts if it exists
|
|
||||||
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
|
|
||||||
if err := os.Remove(knownHostsPath); err != nil && !os.IsNotExist(err) {
|
|
||||||
log.Warnf("Failed to remove system known_hosts %s: %v", knownHostsPath, err)
|
|
||||||
} else if err == nil {
|
|
||||||
log.Infof("Removed NetBird known_hosts: %s", knownHostsPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Also try to clean up user known_hosts
|
|
||||||
homeDir, err := os.UserHomeDir()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to get user home directory: %v", err)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
userKnownHostsPath := filepath.Join(homeDir, ".ssh", m.userKnownHosts)
|
|
||||||
if err := os.Remove(userKnownHostsPath); err != nil && !os.IsNotExist(err) {
|
|
||||||
log.Warnf("Failed to remove user known_hosts %s: %v", userKnownHostsPath, err)
|
|
||||||
} else if err == nil {
|
|
||||||
log.Infof("Removed NetBird user known_hosts: %s", userKnownHostsPath)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,81 +10,8 @@ import (
|
|||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestManager_UpdatePeerHostKeys(t *testing.T) {
|
|
||||||
// Create temporary directory for test
|
|
||||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
|
||||||
|
|
||||||
// Override manager paths to use temp directory
|
|
||||||
manager := &Manager{
|
|
||||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
|
||||||
sshConfigFile: "99-netbird.conf",
|
|
||||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
|
||||||
knownHostsFile: "99-netbird",
|
|
||||||
userKnownHosts: "known_hosts_netbird",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Generate test host keys
|
|
||||||
hostKey1, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
pubKey1, err := ssh.ParsePrivateKey(hostKey1)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
hostKey2, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
pubKey2, err := ssh.ParsePrivateKey(hostKey2)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create test peer host keys
|
|
||||||
peerKeys := []PeerHostKey{
|
|
||||||
{
|
|
||||||
Hostname: "peer1",
|
|
||||||
IP: "100.125.1.1",
|
|
||||||
FQDN: "peer1.nb.internal",
|
|
||||||
HostKey: pubKey1.PublicKey(),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Hostname: "peer2",
|
|
||||||
IP: "100.125.1.2",
|
|
||||||
FQDN: "peer2.nb.internal",
|
|
||||||
HostKey: pubKey2.PublicKey(),
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
// Test updating known_hosts
|
|
||||||
err = manager.UpdatePeerHostKeys(peerKeys)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Verify known_hosts file was created and contains entries
|
|
||||||
knownHostsPath, err := manager.GetKnownHostsPath()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
content, err := os.ReadFile(knownHostsPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
contentStr := string(content)
|
|
||||||
assert.Contains(t, contentStr, "100.125.1.1")
|
|
||||||
assert.Contains(t, contentStr, "100.125.1.2")
|
|
||||||
assert.Contains(t, contentStr, "peer1.nb.internal")
|
|
||||||
assert.Contains(t, contentStr, "peer2.nb.internal")
|
|
||||||
assert.Contains(t, contentStr, "[100.125.1.1]:22")
|
|
||||||
assert.Contains(t, contentStr, "[100.125.1.1]:22022")
|
|
||||||
|
|
||||||
// Test updating with empty list should preserve structure
|
|
||||||
err = manager.UpdatePeerHostKeys([]PeerHostKey{})
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
content, err = os.ReadFile(knownHostsPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
assert.Contains(t, string(content), "# NetBird SSH known hosts")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_SetupSSHClientConfig(t *testing.T) {
|
func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||||
// Create temporary directory for test
|
// Create temporary directory for test
|
||||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||||
@@ -93,15 +20,25 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Override manager paths to use temp directory
|
// Override manager paths to use temp directory
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||||
sshConfigFile: "99-netbird.conf",
|
sshConfigFile: "99-netbird.conf",
|
||||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
|
||||||
knownHostsFile: "99-netbird",
|
|
||||||
userKnownHosts: "known_hosts_netbird",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test SSH config generation with empty peer keys
|
// Test SSH config generation with peers
|
||||||
err = manager.SetupSSHClientConfig(nil)
|
peers := []PeerSSHInfo{
|
||||||
|
{
|
||||||
|
Hostname: "peer1",
|
||||||
|
IP: "100.125.1.1",
|
||||||
|
FQDN: "peer1.nb.internal",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Hostname: "peer2",
|
||||||
|
IP: "100.125.1.2",
|
||||||
|
FQDN: "peer2.nb.internal",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.SetupSSHClientConfig(peers)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Read generated config
|
// Read generated config
|
||||||
@@ -111,134 +48,39 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
|
|||||||
|
|
||||||
configStr := string(content)
|
configStr := string(content)
|
||||||
|
|
||||||
// Since we now use per-peer configurations instead of domain patterns,
|
// Verify the basic SSH config structure exists
|
||||||
// we should verify the basic SSH config structure exists
|
|
||||||
assert.Contains(t, configStr, "# NetBird SSH client configuration")
|
assert.Contains(t, configStr, "# NetBird SSH client configuration")
|
||||||
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
|
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
|
||||||
|
|
||||||
// Should not contain /dev/null since we have a proper known_hosts setup
|
// Check that peer hostnames are included
|
||||||
assert.NotContains(t, configStr, "UserKnownHostsFile /dev/null")
|
assert.Contains(t, configStr, "100.125.1.1")
|
||||||
}
|
assert.Contains(t, configStr, "100.125.1.2")
|
||||||
|
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||||
|
assert.Contains(t, configStr, "peer2.nb.internal")
|
||||||
|
|
||||||
func TestManager_GetHostnameVariants(t *testing.T) {
|
// Check platform-specific UserKnownHostsFile
|
||||||
manager := NewManager()
|
|
||||||
|
|
||||||
peerKey := PeerHostKey{
|
|
||||||
Hostname: "testpeer",
|
|
||||||
IP: "100.125.1.10",
|
|
||||||
FQDN: "testpeer.nb.internal",
|
|
||||||
HostKey: nil, // Not needed for this test
|
|
||||||
}
|
|
||||||
|
|
||||||
variants := manager.getHostnameVariants(peerKey)
|
|
||||||
|
|
||||||
expectedVariants := []string{
|
|
||||||
"100.125.1.10",
|
|
||||||
"testpeer.nb.internal",
|
|
||||||
"testpeer",
|
|
||||||
"[100.125.1.10]:22",
|
|
||||||
"[100.125.1.10]:22022",
|
|
||||||
}
|
|
||||||
|
|
||||||
assert.ElementsMatch(t, expectedVariants, variants)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_FormatKnownHostsEntry(t *testing.T) {
|
|
||||||
manager := NewManager()
|
|
||||||
|
|
||||||
// Generate test key
|
|
||||||
hostKeyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
parsedKey, err := ssh.ParsePrivateKey(hostKeyPEM)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
peerKey := PeerHostKey{
|
|
||||||
Hostname: "testpeer",
|
|
||||||
IP: "100.125.1.10",
|
|
||||||
FQDN: "testpeer.nb.internal",
|
|
||||||
HostKey: parsedKey.PublicKey(),
|
|
||||||
}
|
|
||||||
|
|
||||||
entry := manager.formatKnownHostsEntry(peerKey)
|
|
||||||
|
|
||||||
// Should contain all hostname variants
|
|
||||||
assert.Contains(t, entry, "100.125.1.10")
|
|
||||||
assert.Contains(t, entry, "testpeer.nb.internal")
|
|
||||||
assert.Contains(t, entry, "testpeer")
|
|
||||||
assert.Contains(t, entry, "[100.125.1.10]:22")
|
|
||||||
assert.Contains(t, entry, "[100.125.1.10]:22022")
|
|
||||||
|
|
||||||
// Should contain the public key
|
|
||||||
keyString := string(ssh.MarshalAuthorizedKey(parsedKey.PublicKey()))
|
|
||||||
keyString = strings.TrimSpace(keyString)
|
|
||||||
assert.Contains(t, entry, keyString)
|
|
||||||
|
|
||||||
// Should be properly formatted (hostnames followed by key)
|
|
||||||
parts := strings.Fields(entry)
|
|
||||||
assert.GreaterOrEqual(t, len(parts), 2, "Entry should have hostnames and key parts")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestManager_DirectoryFallback(t *testing.T) {
|
|
||||||
// Create temporary directory for test where system dirs will fail
|
|
||||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
|
||||||
require.NoError(t, err)
|
|
||||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
|
||||||
|
|
||||||
// Set HOME to temp directory to control user fallback
|
|
||||||
t.Setenv("HOME", tempDir)
|
|
||||||
|
|
||||||
// Create manager with non-writable system directories
|
|
||||||
// Use paths that will fail on all systems
|
|
||||||
var failPath string
|
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
failPath = "NUL:" // Special device that can't be used as directory on Windows
|
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
|
||||||
} else {
|
} else {
|
||||||
failPath = "/dev/null" // Special device that can't be used as directory on Unix
|
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||||
}
|
}
|
||||||
|
|
||||||
manager := &Manager{
|
|
||||||
sshConfigDir: failPath + "/ssh_config.d", // Should fail
|
|
||||||
sshConfigFile: "99-netbird.conf",
|
|
||||||
knownHostsDir: failPath + "/ssh_known_hosts.d", // Should fail
|
|
||||||
knownHostsFile: "99-netbird",
|
|
||||||
userKnownHosts: "known_hosts_netbird",
|
|
||||||
}
|
|
||||||
|
|
||||||
// Should fall back to user directory
|
|
||||||
knownHostsPath, err := manager.setupKnownHostsFile()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Get the actual user home directory as determined by os.UserHomeDir()
|
|
||||||
userHome, err := os.UserHomeDir()
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
expectedUserPath := filepath.Join(userHome, ".ssh", "known_hosts_netbird")
|
|
||||||
assert.Equal(t, expectedUserPath, knownHostsPath)
|
|
||||||
|
|
||||||
// Verify file was created
|
|
||||||
_, err = os.Stat(knownHostsPath)
|
|
||||||
require.NoError(t, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestGetSystemSSHPaths(t *testing.T) {
|
func TestGetSystemSSHConfigDir(t *testing.T) {
|
||||||
configDir, knownHostsDir := getSystemSSHPaths()
|
configDir := getSystemSSHConfigDir()
|
||||||
|
|
||||||
// Paths should not be empty
|
// Path should not be empty
|
||||||
assert.NotEmpty(t, configDir)
|
assert.NotEmpty(t, configDir)
|
||||||
assert.NotEmpty(t, knownHostsDir)
|
|
||||||
|
|
||||||
// Should be absolute paths
|
// Should be an absolute path
|
||||||
assert.True(t, filepath.IsAbs(configDir))
|
assert.True(t, filepath.IsAbs(configDir))
|
||||||
assert.True(t, filepath.IsAbs(knownHostsDir))
|
|
||||||
|
|
||||||
// On Unix systems, should start with /etc
|
// On Unix systems, should start with /etc
|
||||||
// On Windows, should contain ProgramData
|
// On Windows, should contain ProgramData
|
||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
assert.Contains(t, strings.ToLower(configDir), "programdata")
|
assert.Contains(t, strings.ToLower(configDir), "programdata")
|
||||||
assert.Contains(t, strings.ToLower(knownHostsDir), "programdata")
|
|
||||||
} else {
|
} else {
|
||||||
assert.Contains(t, configDir, "/etc/ssh")
|
assert.Contains(t, configDir, "/etc/ssh")
|
||||||
assert.Contains(t, knownHostsDir, "/etc/ssh")
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -250,46 +92,28 @@ func TestManager_PeerLimit(t *testing.T) {
|
|||||||
|
|
||||||
// Override manager paths to use temp directory
|
// Override manager paths to use temp directory
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||||
sshConfigFile: "99-netbird.conf",
|
sshConfigFile: "99-netbird.conf",
|
||||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
|
||||||
knownHostsFile: "99-netbird",
|
|
||||||
userKnownHosts: "known_hosts_netbird",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate many peer keys (more than limit)
|
// Generate many peers (more than limit)
|
||||||
var peerKeys []PeerHostKey
|
var peers []PeerSSHInfo
|
||||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
peers = append(peers, PeerSSHInfo{
|
||||||
require.NoError(t, err)
|
|
||||||
pubKey, err := ssh.ParsePrivateKey(hostKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
peerKeys = append(peerKeys, PeerHostKey{
|
|
||||||
Hostname: fmt.Sprintf("peer%d", i),
|
Hostname: fmt.Sprintf("peer%d", i),
|
||||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||||
HostKey: pubKey.PublicKey(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that SSH config generation is skipped when too many peers
|
// Test that SSH config generation is skipped when too many peers
|
||||||
err = manager.SetupSSHClientConfig(peerKeys)
|
err = manager.SetupSSHClientConfig(peers)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Config should not be created due to peer limit
|
// Config should not be created due to peer limit
|
||||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||||
_, err = os.Stat(configPath)
|
_, err = os.Stat(configPath)
|
||||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||||
|
|
||||||
// Test that known_hosts update is also skipped
|
|
||||||
err = manager.UpdatePeerHostKeys(peerKeys)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Known hosts should not be created due to peer limit
|
|
||||||
knownHostsPath := filepath.Join(manager.knownHostsDir, manager.knownHostsFile)
|
|
||||||
_, err = os.Stat(knownHostsPath)
|
|
||||||
assert.True(t, os.IsNotExist(err), "Known hosts should not be created with too many peers")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||||
@@ -303,31 +127,22 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
|
|||||||
|
|
||||||
// Override manager paths to use temp directory
|
// Override manager paths to use temp directory
|
||||||
manager := &Manager{
|
manager := &Manager{
|
||||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||||
sshConfigFile: "99-netbird.conf",
|
sshConfigFile: "99-netbird.conf",
|
||||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
|
||||||
knownHostsFile: "99-netbird",
|
|
||||||
userKnownHosts: "known_hosts_netbird",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Generate many peer keys (more than limit)
|
// Generate many peers (more than limit)
|
||||||
var peerKeys []PeerHostKey
|
var peers []PeerSSHInfo
|
||||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
peers = append(peers, PeerSSHInfo{
|
||||||
require.NoError(t, err)
|
|
||||||
pubKey, err := ssh.ParsePrivateKey(hostKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
peerKeys = append(peerKeys, PeerHostKey{
|
|
||||||
Hostname: fmt.Sprintf("peer%d", i),
|
Hostname: fmt.Sprintf("peer%d", i),
|
||||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||||
HostKey: pubKey.PublicKey(),
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test that SSH config generation is forced despite many peers
|
// Test that SSH config generation is forced despite many peers
|
||||||
err = manager.SetupSSHClientConfig(peerKeys)
|
err = manager.SetupSSHClientConfig(peers)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Config should be created despite peer limit due to force flag
|
// Config should be created despite peer limit due to force flag
|
||||||
|
|||||||
22
client/ssh/config/shutdown_state.go
Normal file
22
client/ssh/config/shutdown_state.go
Normal file
@@ -0,0 +1,22 @@
|
|||||||
|
package config
|
||||||
|
|
||||||
|
// ShutdownState represents SSH configuration state that needs to be cleaned up.
|
||||||
|
type ShutdownState struct {
|
||||||
|
SSHConfigDir string
|
||||||
|
SSHConfigFile string
|
||||||
|
}
|
||||||
|
|
||||||
|
// Name returns the state name for the state manager.
|
||||||
|
func (s *ShutdownState) Name() string {
|
||||||
|
return "ssh_config_state"
|
||||||
|
}
|
||||||
|
|
||||||
|
// Cleanup removes SSH client configuration files.
|
||||||
|
func (s *ShutdownState) Cleanup() error {
|
||||||
|
manager := &Manager{
|
||||||
|
sshConfigDir: s.SSHConfigDir,
|
||||||
|
sshConfigFile: s.SSHConfigFile,
|
||||||
|
}
|
||||||
|
|
||||||
|
return manager.RemoveSSHClientConfig()
|
||||||
|
}
|
||||||
99
client/ssh/detection/detection.go
Normal file
99
client/ssh/detection/detection.go
Normal file
@@ -0,0 +1,99 @@
|
|||||||
|
package detection
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bufio"
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ServerIdentifier is the base response for NetBird SSH servers
|
||||||
|
ServerIdentifier = "NetBird-SSH-Server"
|
||||||
|
// ProxyIdentifier is the base response for NetBird SSH proxy
|
||||||
|
ProxyIdentifier = "NetBird-SSH-Proxy"
|
||||||
|
// JWTRequiredMarker is appended to responses when JWT is required
|
||||||
|
JWTRequiredMarker = "NetBird-JWT-Required"
|
||||||
|
|
||||||
|
// Timeout is the timeout for SSH server detection
|
||||||
|
Timeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type ServerType string
|
||||||
|
|
||||||
|
const (
|
||||||
|
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
|
||||||
|
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
|
||||||
|
ServerTypeRegular ServerType = "regular"
|
||||||
|
)
|
||||||
|
|
||||||
|
// Dialer provides network connection capabilities
|
||||||
|
type Dialer interface {
|
||||||
|
DialContext(ctx context.Context, network, address string) (net.Conn, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
// RequiresJWT checks if the server type requires JWT authentication
|
||||||
|
func (s ServerType) RequiresJWT() bool {
|
||||||
|
return s == ServerTypeNetBirdJWT
|
||||||
|
}
|
||||||
|
|
||||||
|
// ExitCode returns the exit code for the detect command
|
||||||
|
func (s ServerType) ExitCode() int {
|
||||||
|
switch s {
|
||||||
|
case ServerTypeNetBirdJWT:
|
||||||
|
return 0
|
||||||
|
case ServerTypeNetBirdNoJWT:
|
||||||
|
return 1
|
||||||
|
case ServerTypeRegular:
|
||||||
|
return 2
|
||||||
|
default:
|
||||||
|
return 2
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DetectSSHServerType detects SSH server type using the provided dialer
|
||||||
|
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
|
||||||
|
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||||
|
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("SSH connection failed for detection: %v", err)
|
||||||
|
return ServerTypeRegular, nil
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
|
||||||
|
log.Debugf("set read deadline: %v", err)
|
||||||
|
return ServerTypeRegular, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
reader := bufio.NewReader(conn)
|
||||||
|
serverBanner, err := reader.ReadString('\n')
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("read SSH banner: %v", err)
|
||||||
|
return ServerTypeRegular, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
serverBanner = strings.TrimSpace(serverBanner)
|
||||||
|
log.Debugf("SSH server banner: %s", serverBanner)
|
||||||
|
|
||||||
|
if !strings.HasPrefix(serverBanner, "SSH-") {
|
||||||
|
log.Debugf("Invalid SSH banner")
|
||||||
|
return ServerTypeRegular, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if !strings.Contains(serverBanner, ServerIdentifier) {
|
||||||
|
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
|
||||||
|
return ServerTypeRegular, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(serverBanner, JWTRequiredMarker) {
|
||||||
|
return ServerTypeNetBirdJWT, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ServerTypeNetBirdNoJWT, nil
|
||||||
|
}
|
||||||
359
client/ssh/proxy/proxy.go
Normal file
359
client/ssh/proxy/proxy.go
Normal file
@@ -0,0 +1,359 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/gliderlabs/ssh"
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
|
||||||
|
sshConnectionTimeout = 120 * time.Second
|
||||||
|
// sshHandshakeTimeout is the timeout for SSH handshake completion
|
||||||
|
sshHandshakeTimeout = 30 * time.Second
|
||||||
|
|
||||||
|
jwtAuthErrorMsg = "JWT authentication: %w"
|
||||||
|
)
|
||||||
|
|
||||||
|
type SSHProxy struct {
|
||||||
|
daemonAddr string
|
||||||
|
targetHost string
|
||||||
|
targetPort int
|
||||||
|
stderr io.Writer
|
||||||
|
daemonClient proto.DaemonServiceClient
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
|
||||||
|
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
|
||||||
|
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &SSHProxy{
|
||||||
|
daemonAddr: daemonAddr,
|
||||||
|
targetHost: targetHost,
|
||||||
|
targetPort: targetPort,
|
||||||
|
stderr: stderr,
|
||||||
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) Connect(ctx context.Context) error {
|
||||||
|
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf(jwtAuthErrorMsg, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return p.runProxySSHServer(ctx, jwtToken)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
|
||||||
|
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
|
||||||
|
|
||||||
|
sshServer := &ssh.Server{
|
||||||
|
Handler: func(s ssh.Session) {
|
||||||
|
p.handleSSHSession(ctx, s, jwtToken)
|
||||||
|
},
|
||||||
|
ChannelHandlers: map[string]ssh.ChannelHandler{
|
||||||
|
"session": ssh.DefaultSessionHandler,
|
||||||
|
"direct-tcpip": p.directTCPIPHandler,
|
||||||
|
},
|
||||||
|
SubsystemHandlers: map[string]ssh.SubsystemHandler{
|
||||||
|
"sftp": func(s ssh.Session) {
|
||||||
|
p.sftpSubsystemHandler(s, jwtToken)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
RequestHandlers: map[string]ssh.RequestHandler{
|
||||||
|
"tcpip-forward": p.tcpipForwardHandler,
|
||||||
|
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
|
||||||
|
},
|
||||||
|
Version: serverVersion,
|
||||||
|
}
|
||||||
|
|
||||||
|
hostKey, err := generateHostKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("generate host key: %w", err)
|
||||||
|
}
|
||||||
|
sshServer.HostSigners = []ssh.Signer{hostKey}
|
||||||
|
|
||||||
|
conn := &stdioConn{
|
||||||
|
stdin: os.Stdin,
|
||||||
|
stdout: os.Stdout,
|
||||||
|
}
|
||||||
|
|
||||||
|
sshServer.HandleConn(conn)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
|
||||||
|
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||||
|
|
||||||
|
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = sshClient.Close() }()
|
||||||
|
|
||||||
|
serverSession, err := sshClient.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() { _ = serverSession.Close() }()
|
||||||
|
|
||||||
|
serverSession.Stdin = session
|
||||||
|
serverSession.Stdout = session
|
||||||
|
serverSession.Stderr = session.Stderr()
|
||||||
|
|
||||||
|
ptyReq, winCh, isPty := session.Pty()
|
||||||
|
if isPty {
|
||||||
|
_ = serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for win := range winCh {
|
||||||
|
_ = serverSession.WindowChange(win.Height, win.Width)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(session.Command()) > 0 {
|
||||||
|
_ = serverSession.Run(strings.Join(session.Command(), " "))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err = serverSession.Shell(); err == nil {
|
||||||
|
_ = serverSession.Wait()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateHostKey() (ssh.Signer, error) {
|
||||||
|
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate ED25519 key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := cryptossh.ParsePrivateKey(keyPEM)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return signer, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type stdioConn struct {
|
||||||
|
stdin io.Reader
|
||||||
|
stdout io.Writer
|
||||||
|
closed bool
|
||||||
|
mu sync.Mutex
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) Read(b []byte) (n int, err error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return 0, io.EOF
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
return c.stdin.Read(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) Write(b []byte) (n int, err error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
if c.closed {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return 0, io.ErrClosedPipe
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
return c.stdout.Write(b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) Close() error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
c.closed = true
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) LocalAddr() net.Addr {
|
||||||
|
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) RemoteAddr() net.Addr {
|
||||||
|
return &net.UnixAddr{Name: "stdio", Net: "unix"}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) SetDeadline(_ time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
|
||||||
|
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
|
||||||
|
ctx, cancel := context.WithCancel(s.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
|
||||||
|
|
||||||
|
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
|
||||||
|
_ = s.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := sshClient.Close(); err != nil {
|
||||||
|
log.Debugf("close SSH client: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
serverSession, err := sshClient.NewSession()
|
||||||
|
if err != nil {
|
||||||
|
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
|
||||||
|
_ = s.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := serverSession.Close(); err != nil {
|
||||||
|
log.Debugf("close server session: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
stdin, stdout, err := p.setupSFTPPipes(serverSession)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("setup SFTP pipes: %v", err)
|
||||||
|
_ = s.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := serverSession.RequestSubsystem("sftp"); err != nil {
|
||||||
|
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
|
||||||
|
_ = s.Exit(1)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
|
||||||
|
stdin, err := serverSession.StdinPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
stdout, err := serverSession.StdoutPipe()
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return stdin, stdout, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
|
||||||
|
copyErrCh := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(stdin, s)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("SFTP client to server copy: %v", err)
|
||||||
|
}
|
||||||
|
if err := stdin.Close(); err != nil {
|
||||||
|
log.Debugf("close stdin: %v", err)
|
||||||
|
}
|
||||||
|
copyErrCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(s, stdout)
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("SFTP server to client copy: %v", err)
|
||||||
|
}
|
||||||
|
copyErrCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
if err := serverSession.Close(); err != nil {
|
||||||
|
log.Debugf("force close server session on context cancellation: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for i := 0; i < 2; i++ {
|
||||||
|
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
|
||||||
|
log.Debugf("SFTP copy error: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := serverSession.Wait(); err != nil {
|
||||||
|
log.Debugf("SFTP session ended: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||||
|
return false, []byte("port forwarding not supported in proxy")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
|
||||||
|
config := &cryptossh.ClientConfig{
|
||||||
|
User: user,
|
||||||
|
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
|
||||||
|
Timeout: sshHandshakeTimeout,
|
||||||
|
HostKeyCallback: p.verifyHostKey,
|
||||||
|
}
|
||||||
|
|
||||||
|
dialer := &net.Dialer{
|
||||||
|
Timeout: sshConnectionTimeout,
|
||||||
|
}
|
||||||
|
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("connect to server: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
|
||||||
|
if err != nil {
|
||||||
|
_ = conn.Close()
|
||||||
|
return nil, fmt.Errorf("SSH handshake: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return cryptossh.NewClient(clientConn, chans, reqs), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
|
||||||
|
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
|
||||||
|
callback := nbssh.CreateHostKeyCallback(verifier)
|
||||||
|
return callback(hostname, remote, key)
|
||||||
|
}
|
||||||
361
client/ssh/proxy/proxy_test.go
Normal file
361
client/ssh/proxy/proxy_test.go
Normal file
@@ -0,0 +1,361 @@
|
|||||||
|
package proxy
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/server"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestMain(m *testing.M) {
|
||||||
|
if len(os.Args) > 2 && os.Args[1] == "ssh" {
|
||||||
|
if os.Args[2] == "exec" {
|
||||||
|
if len(os.Args) > 3 {
|
||||||
|
cmd := os.Args[3]
|
||||||
|
if cmd == "echo" && len(os.Args) > 4 {
|
||||||
|
fmt.Fprintln(os.Stdout, os.Args[4])
|
||||||
|
os.Exit(0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
code := m.Run()
|
||||||
|
|
||||||
|
testutil.CleanupTestUsers()
|
||||||
|
|
||||||
|
os.Exit(code)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHProxy_verifyHostKey(t *testing.T) {
|
||||||
|
t.Run("calls daemon to verify host key", func(t *testing.T) {
|
||||||
|
mockDaemon := startMockDaemon(t)
|
||||||
|
defer mockDaemon.stop()
|
||||||
|
|
||||||
|
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = grpcConn.Close() }()
|
||||||
|
|
||||||
|
proxy := &SSHProxy{
|
||||||
|
daemonAddr: mockDaemon.addr,
|
||||||
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||||
|
}
|
||||||
|
|
||||||
|
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
testPubKey, err := nbssh.GeneratePublicKey(testKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockDaemon.setHostKey("test-host", testPubKey)
|
||||||
|
|
||||||
|
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
|
||||||
|
assert.NoError(t, err)
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("rejects unknown host key", func(t *testing.T) {
|
||||||
|
mockDaemon := startMockDaemon(t)
|
||||||
|
defer mockDaemon.stop()
|
||||||
|
|
||||||
|
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() { _ = grpcConn.Close() }()
|
||||||
|
|
||||||
|
proxy := &SSHProxy{
|
||||||
|
daemonAddr: mockDaemon.addr,
|
||||||
|
daemonClient: proto.NewDaemonServiceClient(grpcConn),
|
||||||
|
}
|
||||||
|
|
||||||
|
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSSHProxy_Connect(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
serverConfig := &server.Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: &server.JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audience: audience,
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
sshServer := server.New(serverConfig)
|
||||||
|
sshServer.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
sshServerAddr := server.StartTestServer(t, sshServer)
|
||||||
|
defer func() { _ = sshServer.Stop() }()
|
||||||
|
|
||||||
|
mockDaemon := startMockDaemon(t)
|
||||||
|
defer mockDaemon.stop()
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(sshServerAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
mockDaemon.setHostKey(host, hostPubKey)
|
||||||
|
|
||||||
|
validToken := generateValidJWT(t, privateKey, issuer, audience)
|
||||||
|
mockDaemon.setJWTToken(validToken)
|
||||||
|
|
||||||
|
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
clientConn, proxyConn := net.Pipe()
|
||||||
|
defer func() { _ = clientConn.Close() }()
|
||||||
|
|
||||||
|
origStdin := os.Stdin
|
||||||
|
origStdout := os.Stdout
|
||||||
|
defer func() {
|
||||||
|
os.Stdin = origStdin
|
||||||
|
os.Stdout = origStdout
|
||||||
|
}()
|
||||||
|
|
||||||
|
stdinReader, stdinWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
stdoutReader, stdoutWriter, err := os.Pipe()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
os.Stdin = stdinReader
|
||||||
|
os.Stdout = stdoutWriter
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(stdinWriter, proxyConn)
|
||||||
|
}()
|
||||||
|
go func() {
|
||||||
|
_, _ = io.Copy(proxyConn, stdoutReader)
|
||||||
|
}()
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
connectErrCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
connectErrCh <- proxyInstance.Connect(ctx)
|
||||||
|
}()
|
||||||
|
|
||||||
|
sshConfig := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 3 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
|
||||||
|
require.NoError(t, err, "Should connect to proxy server")
|
||||||
|
defer func() { _ = sshClientConn.Close() }()
|
||||||
|
|
||||||
|
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
|
||||||
|
|
||||||
|
session, err := sshClient.NewSession()
|
||||||
|
require.NoError(t, err, "Should create session through full proxy to backend")
|
||||||
|
|
||||||
|
outputCh := make(chan []byte, 1)
|
||||||
|
errCh := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
output, err := session.Output("echo hello-from-proxy")
|
||||||
|
outputCh <- output
|
||||||
|
errCh <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case output := <-outputCh:
|
||||||
|
err := <-errCh
|
||||||
|
require.NoError(t, err, "Command should execute successfully through proxy")
|
||||||
|
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
|
||||||
|
case <-time.After(3 * time.Second):
|
||||||
|
t.Fatal("Command execution timed out")
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = session.Close()
|
||||||
|
_ = sshClient.Close()
|
||||||
|
_ = clientConn.Close()
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockDaemonServer struct {
|
||||||
|
proto.UnimplementedDaemonServiceServer
|
||||||
|
hostKeys map[string][]byte
|
||||||
|
jwtToken string
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
|
||||||
|
key, found := m.hostKeys[req.PeerAddress]
|
||||||
|
return &proto.GetPeerSSHHostKeyResponse{
|
||||||
|
Found: found,
|
||||||
|
SshHostKey: key,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
|
||||||
|
return &proto.RequestJWTAuthResponse{
|
||||||
|
CachedToken: m.jwtToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
|
||||||
|
return &proto.WaitJWTTokenResponse{
|
||||||
|
Token: m.jwtToken,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type mockDaemon struct {
|
||||||
|
addr string
|
||||||
|
server *grpc.Server
|
||||||
|
impl *mockDaemonServer
|
||||||
|
}
|
||||||
|
|
||||||
|
func startMockDaemon(t *testing.T) *mockDaemon {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
listener, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
impl := &mockDaemonServer{
|
||||||
|
hostKeys: make(map[string][]byte),
|
||||||
|
jwtToken: "test-jwt-token",
|
||||||
|
}
|
||||||
|
|
||||||
|
grpcServer := grpc.NewServer()
|
||||||
|
proto.RegisterDaemonServiceServer(grpcServer, impl)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_ = grpcServer.Serve(listener)
|
||||||
|
}()
|
||||||
|
|
||||||
|
return &mockDaemon{
|
||||||
|
addr: listener.Addr().String(),
|
||||||
|
server: grpcServer,
|
||||||
|
impl: impl,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
|
||||||
|
m.impl.hostKeys[addr] = pubKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDaemon) setJWTToken(token string) {
|
||||||
|
m.impl.jwtToken = token
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *mockDaemon) stop() {
|
||||||
|
if m.server != nil {
|
||||||
|
m.server.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
|
||||||
|
t.Helper()
|
||||||
|
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return pubKey
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||||
|
t.Helper()
|
||||||
|
privateKey, jwksJSON := generateTestJWKS(t)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if _, err := w.Write(jwksJSON); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
return server, privateKey, server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||||
|
t.Helper()
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKey := &privateKey.PublicKey
|
||||||
|
n := publicKey.N.Bytes()
|
||||||
|
e := publicKey.E
|
||||||
|
|
||||||
|
jwk := nbjwt.JSONWebKey{
|
||||||
|
Kty: "RSA",
|
||||||
|
Kid: "test-key-id",
|
||||||
|
Use: "sig",
|
||||||
|
N: base64.RawURLEncoding.EncodeToString(n),
|
||||||
|
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks := nbjwt.Jwks{
|
||||||
|
Keys: []nbjwt.JSONWebKey{jwk},
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksJSON, err := json.Marshal(jwks)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return privateKey, jwksJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||||
|
t.Helper()
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": audience,
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
token.Header["kid"] = "test-key-id"
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString(privateKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"os/user"
|
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -21,15 +20,24 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestMain handles package-level setup and cleanup
|
// TestMain handles package-level setup and cleanup
|
||||||
func TestMain(m *testing.M) {
|
func TestMain(m *testing.M) {
|
||||||
|
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
|
||||||
|
// This happens when running tests as non-privileged user with fallback
|
||||||
|
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
|
||||||
|
// Just exit with error to break the recursion
|
||||||
|
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
|
||||||
// Run tests
|
// Run tests
|
||||||
code := m.Run()
|
code := m.Run()
|
||||||
|
|
||||||
// Cleanup any created test users
|
// Cleanup any created test users
|
||||||
cleanupTestUsers()
|
testutil.CleanupTestUsers()
|
||||||
|
|
||||||
os.Exit(code)
|
os.Exit(code)
|
||||||
}
|
}
|
||||||
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Generate OpenSSH-compatible keys for client
|
// Generate OpenSSH-compatible keys for client
|
||||||
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t)
|
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
HostKeyPEM: hostKey,
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH))
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Get appropriate user for SSH connection (handle system accounts)
|
// Get appropriate user for SSH connection (handle system accounts)
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
t.Run("basic command execution", func(t *testing.T) {
|
t.Run("basic command execution", func(t *testing.T) {
|
||||||
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
|
||||||
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
|
|||||||
// testSSHInteractiveCommand tests interactive shell session.
|
// testSSHInteractiveCommand tests interactive shell session.
|
||||||
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
||||||
// Get appropriate user for SSH connection
|
// Get appropriate user for SSH connection
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
|
|||||||
// testSSHPortForwarding tests port forwarding compatibility.
|
// testSSHPortForwarding tests port forwarding compatibility.
|
||||||
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
|
||||||
// Get appropriate user for SSH connection
|
// Get appropriate user for SSH connection
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
testServer, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
|||||||
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
t.Skip("Skipping SSH feature compatibility tests in short mode")
|
||||||
}
|
}
|
||||||
|
|
||||||
if runtime.GOOS == "windows" && isCI() {
|
if runtime.GOOS == "windows" && testutil.IsCI() {
|
||||||
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
|||||||
|
|
||||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
HostKeyPEM: hostKey,
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
|
|||||||
// testCommandWithFlags tests that commands with flags work properly
|
// testCommandWithFlags tests that commands with flags work properly
|
||||||
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
||||||
// Get appropriate user for SSH connection
|
// Get appropriate user for SSH connection
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
// Test ls with flags
|
// Test ls with flags
|
||||||
cmd := exec.Command("ssh",
|
cmd := exec.Command("ssh",
|
||||||
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
|
|||||||
// testEnvironmentVariables tests that environment is properly set up
|
// testEnvironmentVariables tests that environment is properly set up
|
||||||
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
||||||
// Get appropriate user for SSH connection
|
// Get appropriate user for SSH connection
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
cmd := exec.Command("ssh",
|
cmd := exec.Command("ssh",
|
||||||
"-i", keyFile,
|
"-i", keyFile,
|
||||||
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
|
|||||||
// testExitCodes tests that exit codes are properly handled
|
// testExitCodes tests that exit codes are properly handled
|
||||||
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
func testExitCodes(t *testing.T, host, port, keyFile string) {
|
||||||
// Get appropriate user for SSH connection
|
// Get appropriate user for SSH connection
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
// Test successful command (exit code 0)
|
// Test successful command (exit code 0)
|
||||||
cmd := exec.Command("ssh",
|
cmd := exec.Command("ssh",
|
||||||
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get appropriate user for SSH connection
|
// Get appropriate user for SSH connection
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
// Set up SSH server with specific security settings
|
// Set up SSH server with specific security settings
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
|
|||||||
|
|
||||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
HostKeyPEM: hostKey,
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Get appropriate user for SSH connection
|
// Get appropriate user for SSH connection
|
||||||
username := getTestUsername(t)
|
username := testutil.GetTestUsername(t)
|
||||||
|
|
||||||
// Set up SSH server
|
// Set up SSH server
|
||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
|||||||
|
|
||||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
HostKeyPEM: hostKey,
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
|
|||||||
t.Logf("Platform command output: %s", outputStr)
|
t.Logf("Platform command output: %s", outputStr)
|
||||||
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
|
||||||
}
|
}
|
||||||
|
|
||||||
// getTestUsername returns an appropriate username for testing
|
|
||||||
func getTestUsername(t *testing.T) string {
|
|
||||||
if runtime.GOOS == "windows" {
|
|
||||||
currentUser, err := user.Current()
|
|
||||||
require.NoError(t, err, "Should be able to get current user")
|
|
||||||
|
|
||||||
// Check if this is a system account that can't authenticate
|
|
||||||
if isSystemAccount(currentUser.Username) {
|
|
||||||
// In CI environments, create a test user; otherwise try Administrator
|
|
||||||
if isCI() {
|
|
||||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
|
||||||
return testUser
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
// Try Administrator first for local development
|
|
||||||
if _, err := user.Lookup("Administrator"); err == nil {
|
|
||||||
return "Administrator"
|
|
||||||
}
|
|
||||||
if testUser := getOrCreateTestUser(t); testUser != "" {
|
|
||||||
return testUser
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return currentUser.Username
|
|
||||||
}
|
|
||||||
|
|
||||||
currentUser, err := user.Current()
|
|
||||||
require.NoError(t, err, "Should be able to get current user")
|
|
||||||
return currentUser.Username
|
|
||||||
}
|
|
||||||
|
|
||||||
// isCI checks if we're running in a CI environment
|
|
||||||
func isCI() bool {
|
|
||||||
// Check standard CI environment variables
|
|
||||||
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
|
|
||||||
hostname, err := os.Hostname()
|
|
||||||
if err == nil && strings.HasPrefix(hostname, "runner") {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// isSystemAccount checks if the user is a system account that can't authenticate
|
|
||||||
func isSystemAccount(username string) bool {
|
|
||||||
systemAccounts := []string{
|
|
||||||
"system",
|
|
||||||
"NT AUTHORITY\\SYSTEM",
|
|
||||||
"NT AUTHORITY\\LOCAL SERVICE",
|
|
||||||
"NT AUTHORITY\\NETWORK SERVICE",
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, sysAccount := range systemAccounts {
|
|
||||||
if strings.EqualFold(username, sysAccount) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
var compatTestCreatedUsers = make(map[string]bool)
|
|
||||||
var compatTestUsersToCleanup []string
|
|
||||||
|
|
||||||
// registerTestUserCleanup registers a test user for cleanup
|
|
||||||
func registerTestUserCleanup(username string) {
|
|
||||||
if !compatTestCreatedUsers[username] {
|
|
||||||
compatTestCreatedUsers[username] = true
|
|
||||||
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupTestUsers removes all created test users
|
|
||||||
func cleanupTestUsers() {
|
|
||||||
for _, username := range compatTestUsersToCleanup {
|
|
||||||
removeWindowsTestUser(username)
|
|
||||||
}
|
|
||||||
compatTestUsersToCleanup = nil
|
|
||||||
compatTestCreatedUsers = make(map[string]bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getOrCreateTestUser creates a test user on Windows if needed
|
|
||||||
func getOrCreateTestUser(t *testing.T) string {
|
|
||||||
testUsername := "netbird-test-user"
|
|
||||||
|
|
||||||
// Check if user already exists
|
|
||||||
if _, err := user.Lookup(testUsername); err == nil {
|
|
||||||
return testUsername
|
|
||||||
}
|
|
||||||
|
|
||||||
// Try to create the user using PowerShell
|
|
||||||
if createWindowsTestUser(t, testUsername) {
|
|
||||||
// Register cleanup for the test user
|
|
||||||
registerTestUserCleanup(testUsername)
|
|
||||||
return testUsername
|
|
||||||
}
|
|
||||||
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// removeWindowsTestUser removes a local user on Windows using PowerShell
|
|
||||||
func removeWindowsTestUser(username string) {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// PowerShell command to remove a local user
|
|
||||||
psCmd := fmt.Sprintf(`
|
|
||||||
try {
|
|
||||||
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
|
||||||
Write-Output "User removed successfully"
|
|
||||||
} catch {
|
|
||||||
if ($_.Exception.Message -like "*cannot be found*") {
|
|
||||||
Write-Output "User not found (already removed)"
|
|
||||||
} else {
|
|
||||||
Write-Error $_.Exception.Message
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`, username)
|
|
||||||
|
|
||||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
|
||||||
} else {
|
|
||||||
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// createWindowsTestUser creates a local user on Windows using PowerShell
|
|
||||||
func createWindowsTestUser(t *testing.T, username string) bool {
|
|
||||||
if runtime.GOOS != "windows" {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
// PowerShell command to create a local user
|
|
||||||
psCmd := fmt.Sprintf(`
|
|
||||||
try {
|
|
||||||
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
|
||||||
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
|
||||||
Add-LocalGroupMember -Group "Users" -Member "%s"
|
|
||||||
Write-Output "User created successfully"
|
|
||||||
} catch {
|
|
||||||
if ($_.Exception.Message -like "*already exists*") {
|
|
||||||
Write-Output "User already exists"
|
|
||||||
} else {
|
|
||||||
Write-Error $_.Exception.Message
|
|
||||||
exit 1
|
|
||||||
}
|
|
||||||
}
|
|
||||||
`, username, username)
|
|
||||||
|
|
||||||
cmd := exec.Command("powershell", "-Command", psCmd)
|
|
||||||
output, err := cmd.CombinedOutput()
|
|
||||||
|
|
||||||
if err != nil {
|
|
||||||
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
t.Logf("Test user creation result: %s", string(output))
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|||||||
610
client/ssh/server/jwt_test.go
Normal file
610
client/ssh/server/jwt_test.go
Normal file
@@ -0,0 +1,610 @@
|
|||||||
|
package server
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
|
"io"
|
||||||
|
"math/big"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/http/httptest"
|
||||||
|
"strconv"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang-jwt/jwt/v5"
|
||||||
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/client"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/testutil"
|
||||||
|
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestJWTEnforcement(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping JWT enforcement tests in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Set up SSH server
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Run("blocks_without_jwt", func(t *testing.T) {
|
||||||
|
jwtConfig := &JWTConfig{
|
||||||
|
Issuer: "test-issuer",
|
||||||
|
Audience: "test-audience",
|
||||||
|
KeysLocation: "test-keys",
|
||||||
|
}
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer require.NoError(t, server.Stop())
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||||
|
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Detection failed: %v", err)
|
||||||
|
}
|
||||||
|
t.Logf("Detected server type: %s", serverType)
|
||||||
|
|
||||||
|
config := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||||
|
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allows_when_disabled", func(t *testing.T) {
|
||||||
|
serverConfigNoJWT := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
serverNoJWT := New(serverConfigNoJWT)
|
||||||
|
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
|
||||||
|
serverNoJWT.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
|
||||||
|
defer require.NoError(t, serverNoJWT.Stop())
|
||||||
|
|
||||||
|
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
|
||||||
|
require.NoError(t, err)
|
||||||
|
portNoJWT, err := strconv.Atoi(portStrNoJWT)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||||
|
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
|
||||||
|
assert.False(t, serverType.RequiresJWT())
|
||||||
|
|
||||||
|
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer client.Close()
|
||||||
|
})
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
|
||||||
|
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
|
||||||
|
privateKey, jwksJSON := generateTestJWKS(t)
|
||||||
|
|
||||||
|
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
w.Header().Set("Content-Type", "application/json")
|
||||||
|
if _, err := w.Write(jwksJSON); err != nil {
|
||||||
|
http.Error(w, err.Error(), http.StatusInternalServerError)
|
||||||
|
}
|
||||||
|
}))
|
||||||
|
|
||||||
|
return server, privateKey, server.URL
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
|
||||||
|
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
|
||||||
|
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
publicKey := &privateKey.PublicKey
|
||||||
|
n := publicKey.N.Bytes()
|
||||||
|
e := publicKey.E
|
||||||
|
|
||||||
|
jwk := nbjwt.JSONWebKey{
|
||||||
|
Kty: "RSA",
|
||||||
|
Kid: "test-key-id",
|
||||||
|
Use: "sig",
|
||||||
|
N: base64RawURLEncode(n),
|
||||||
|
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
|
||||||
|
}
|
||||||
|
|
||||||
|
jwks := nbjwt.Jwks{
|
||||||
|
Keys: []nbjwt.JSONWebKey{jwk},
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksJSON, err := json.Marshal(jwks)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return privateKey, jwksJSON
|
||||||
|
}
|
||||||
|
|
||||||
|
func base64RawURLEncode(data []byte) string {
|
||||||
|
return base64.RawURLEncoding.EncodeToString(data)
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateValidJWT creates a valid JWT token for testing
|
||||||
|
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
|
||||||
|
claims := jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": audience,
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
}
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
|
||||||
|
token.Header["kid"] = "test-key-id"
|
||||||
|
|
||||||
|
tokenString, err := token.SignedString(privateKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
return tokenString
|
||||||
|
}
|
||||||
|
|
||||||
|
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
|
||||||
|
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
|
||||||
|
t.Helper()
|
||||||
|
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
|
||||||
|
func TestJWTDetection(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping JWT detection test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksServer, _, jwksURL := setupJWKSServer(t)
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
jwtConfig := &JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audience: audience,
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
}
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer require.NoError(t, server.Stop())
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
dialer := &net.Dialer{Timeout: detection.Timeout}
|
||||||
|
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
|
||||||
|
assert.True(t, serverType.RequiresJWT())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestJWTFailClose(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping JWT fail-close tests in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
tokenClaims jwt.MapClaims
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "blocks_token_missing_iat",
|
||||||
|
tokenClaims: jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": audience,
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_token_missing_sub",
|
||||||
|
tokenClaims: jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": audience,
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_token_missing_iss",
|
||||||
|
tokenClaims: jwt.MapClaims{
|
||||||
|
"aud": audience,
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_token_missing_aud",
|
||||||
|
tokenClaims: jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_token_wrong_issuer",
|
||||||
|
tokenClaims: jwt.MapClaims{
|
||||||
|
"iss": "wrong-issuer",
|
||||||
|
"aud": audience,
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_token_wrong_audience",
|
||||||
|
tokenClaims: jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": "wrong-audience",
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Unix(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_expired_token",
|
||||||
|
tokenClaims: jwt.MapClaims{
|
||||||
|
"iss": issuer,
|
||||||
|
"aud": audience,
|
||||||
|
"sub": "test-user",
|
||||||
|
"exp": time.Now().Add(-time.Hour).Unix(),
|
||||||
|
"iat": time.Now().Add(-2 * time.Hour).Unix(),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
jwtConfig := &JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audience: audience,
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
MaxTokenAge: 3600,
|
||||||
|
}
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer require.NoError(t, server.Stop())
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
|
||||||
|
token.Header["kid"] = "test-key-id"
|
||||||
|
tokenString, err := token.SignedString(privateKey)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
config := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: []cryptossh.AuthMethod{
|
||||||
|
cryptossh.Password(tokenString),
|
||||||
|
},
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||||
|
if conn != nil {
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
t.Logf("close connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Error(t, err, "Authentication should fail (fail-close)")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
|
||||||
|
func TestJWTAuthentication(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("Skipping JWT authentication tests in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
|
||||||
|
defer jwksServer.Close()
|
||||||
|
|
||||||
|
const (
|
||||||
|
issuer = "https://test-issuer.example.com"
|
||||||
|
audience = "test-audience"
|
||||||
|
)
|
||||||
|
|
||||||
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
token string
|
||||||
|
wantAuthOK bool
|
||||||
|
setupServer func(*Server)
|
||||||
|
testOperation func(*testing.T, *cryptossh.Client, string) error
|
||||||
|
wantOpSuccess bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "allows_shell_with_jwt",
|
||||||
|
token: "valid",
|
||||||
|
wantAuthOK: true,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
return session.Shell()
|
||||||
|
},
|
||||||
|
wantOpSuccess: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "rejects_invalid_token",
|
||||||
|
token: "invalid",
|
||||||
|
wantAuthOK: false,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
output, err := session.CombinedOutput("echo test")
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Command output: %s", string(output))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantOpSuccess: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_shell_without_jwt",
|
||||||
|
token: "",
|
||||||
|
wantAuthOK: false,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
output, err := session.CombinedOutput("echo test")
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Command output: %s", string(output))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantOpSuccess: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_command_without_jwt",
|
||||||
|
token: "",
|
||||||
|
wantAuthOK: false,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
output, err := session.CombinedOutput("ls")
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Command output: %s", string(output))
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
wantOpSuccess: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows_sftp_with_jwt",
|
||||||
|
token: "valid",
|
||||||
|
wantAuthOK: true,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
s.SetAllowSFTP(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
session.Stdout = io.Discard
|
||||||
|
session.Stderr = io.Discard
|
||||||
|
return session.RequestSubsystem("sftp")
|
||||||
|
},
|
||||||
|
wantOpSuccess: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_sftp_without_jwt",
|
||||||
|
token: "",
|
||||||
|
wantAuthOK: false,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
s.SetAllowSFTP(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
session, err := conn.NewSession()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer session.Close()
|
||||||
|
|
||||||
|
session.Stdout = io.Discard
|
||||||
|
session.Stderr = io.Discard
|
||||||
|
err = session.RequestSubsystem("sftp")
|
||||||
|
if err == nil {
|
||||||
|
err = session.Wait()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
wantOpSuccess: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "allows_port_forward_with_jwt",
|
||||||
|
token: "valid",
|
||||||
|
wantAuthOK: true,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
s.SetAllowRemotePortForwarding(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if ln != nil {
|
||||||
|
defer ln.Close()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
wantOpSuccess: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "blocks_port_forward_without_jwt",
|
||||||
|
token: "",
|
||||||
|
wantAuthOK: false,
|
||||||
|
setupServer: func(s *Server) {
|
||||||
|
s.SetAllowRootLogin(true)
|
||||||
|
s.SetAllowLocalPortForwarding(true)
|
||||||
|
},
|
||||||
|
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
|
||||||
|
ln, err := conn.Listen("tcp", "127.0.0.1:0")
|
||||||
|
if ln != nil {
|
||||||
|
defer ln.Close()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
},
|
||||||
|
wantOpSuccess: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
jwtConfig := &JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audience: audience,
|
||||||
|
KeysLocation: jwksURL,
|
||||||
|
}
|
||||||
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: jwtConfig,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
if tc.setupServer != nil {
|
||||||
|
tc.setupServer(server)
|
||||||
|
}
|
||||||
|
|
||||||
|
serverAddr := StartTestServer(t, server)
|
||||||
|
defer require.NoError(t, server.Stop())
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(serverAddr)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var authMethods []cryptossh.AuthMethod
|
||||||
|
if tc.token == "valid" {
|
||||||
|
token := generateValidJWT(t, privateKey, issuer, audience)
|
||||||
|
authMethods = []cryptossh.AuthMethod{
|
||||||
|
cryptossh.Password(token),
|
||||||
|
}
|
||||||
|
} else if tc.token == "invalid" {
|
||||||
|
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
|
||||||
|
authMethods = []cryptossh.AuthMethod{
|
||||||
|
cryptossh.Password(invalidToken),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config := &cryptossh.ClientConfig{
|
||||||
|
User: testutil.GetTestUsername(t),
|
||||||
|
Auth: authMethods,
|
||||||
|
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
|
||||||
|
Timeout: 2 * time.Second,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
|
||||||
|
if tc.wantAuthOK {
|
||||||
|
require.NoError(t, err, "JWT authentication should succeed")
|
||||||
|
} else if err != nil {
|
||||||
|
t.Logf("Connection failed as expected: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if conn != nil {
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
t.Logf("close connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
err = tc.testOperation(t, conn, serverAddr)
|
||||||
|
if tc.wantOpSuccess {
|
||||||
|
require.NoError(t, err, "Operation should succeed")
|
||||||
|
} else {
|
||||||
|
assert.Error(t, err, "Operation should fail")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -2,18 +2,27 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"encoding/base64"
|
||||||
|
"encoding/json"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
"github.com/gliderlabs/ssh"
|
||||||
|
gojwt "github.com/golang-jwt/jwt/v5"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
|
"github.com/netbirdio/netbird/management/server/auth/jwt"
|
||||||
|
nbcontext "github.com/netbirdio/netbird/management/server/context"
|
||||||
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
|
||||||
@@ -27,6 +36,9 @@ const (
|
|||||||
errExitSession = "exit session error: %v"
|
errExitSession = "exit session error: %v"
|
||||||
|
|
||||||
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
msgPrivilegedUserDisabled = "privileged user login is disabled"
|
||||||
|
|
||||||
|
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
|
||||||
|
DefaultJWTMaxTokenAge = 5 * 60
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -69,7 +81,6 @@ func (e *UserNotFoundError) Unwrap() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// safeLogCommand returns a safe representation of the command for logging
|
// safeLogCommand returns a safe representation of the command for logging
|
||||||
// Only logs the first argument to avoid leaking sensitive information
|
|
||||||
func safeLogCommand(cmd []string) string {
|
func safeLogCommand(cmd []string) string {
|
||||||
if len(cmd) == 0 {
|
if len(cmd) == 0 {
|
||||||
return "<empty>"
|
return "<empty>"
|
||||||
@@ -80,17 +91,14 @@ func safeLogCommand(cmd []string) string {
|
|||||||
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sshConnectionState tracks the state of an SSH connection
|
|
||||||
type sshConnectionState struct {
|
type sshConnectionState struct {
|
||||||
hasActivePortForward bool
|
hasActivePortForward bool
|
||||||
username string
|
username string
|
||||||
remoteAddr string
|
remoteAddr string
|
||||||
}
|
}
|
||||||
|
|
||||||
// Server is the SSH server implementation
|
|
||||||
type Server struct {
|
type Server struct {
|
||||||
sshServer *ssh.Server
|
sshServer *ssh.Server
|
||||||
authorizedKeys map[string]ssh.PublicKey
|
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
hostKeyPEM []byte
|
hostKeyPEM []byte
|
||||||
sessions map[SessionKey]ssh.Session
|
sessions map[SessionKey]ssh.Session
|
||||||
@@ -100,30 +108,53 @@ type Server struct {
|
|||||||
allowRemotePortForwarding bool
|
allowRemotePortForwarding bool
|
||||||
allowRootLogin bool
|
allowRootLogin bool
|
||||||
allowSFTP bool
|
allowSFTP bool
|
||||||
|
jwtEnabled bool
|
||||||
|
|
||||||
netstackNet *netstack.Net
|
netstackNet *netstack.Net
|
||||||
|
|
||||||
wgAddress wgaddr.Address
|
wgAddress wgaddr.Address
|
||||||
ifIdx int
|
|
||||||
|
|
||||||
remoteForwardListeners map[ForwardKey]net.Listener
|
remoteForwardListeners map[ForwardKey]net.Listener
|
||||||
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
|
||||||
|
|
||||||
|
jwtValidator *jwt.Validator
|
||||||
|
jwtExtractor *jwt.ClaimsExtractor
|
||||||
|
jwtConfig *JWTConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
// New creates an SSH server instance with the provided host key
|
type JWTConfig struct {
|
||||||
func New(hostKeyPEM []byte) *Server {
|
Issuer string
|
||||||
return &Server{
|
Audience string
|
||||||
|
KeysLocation string
|
||||||
|
MaxTokenAge int64
|
||||||
|
}
|
||||||
|
|
||||||
|
// Config contains all SSH server configuration options
|
||||||
|
type Config struct {
|
||||||
|
// JWT authentication configuration. If nil, JWT authentication is disabled
|
||||||
|
JWT *JWTConfig
|
||||||
|
|
||||||
|
// HostKey is the SSH server host key in PEM format
|
||||||
|
HostKeyPEM []byte
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates an SSH server instance with the provided host key and optional JWT configuration
|
||||||
|
// If jwtConfig is nil, JWT authentication is disabled
|
||||||
|
func New(config *Config) *Server {
|
||||||
|
s := &Server{
|
||||||
mu: sync.RWMutex{},
|
mu: sync.RWMutex{},
|
||||||
hostKeyPEM: hostKeyPEM,
|
hostKeyPEM: config.HostKeyPEM,
|
||||||
authorizedKeys: make(map[string]ssh.PublicKey),
|
|
||||||
sessions: make(map[SessionKey]ssh.Session),
|
sessions: make(map[SessionKey]ssh.Session),
|
||||||
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
remoteForwardListeners: make(map[ForwardKey]net.Listener),
|
||||||
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
|
||||||
|
jwtEnabled: config.JWT != nil,
|
||||||
|
jwtConfig: config.JWT,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return s
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start runs the SSH server, automatically detecting netstack vs standard networking
|
// Start runs the SSH server
|
||||||
// Does all setup synchronously, then starts serving in a goroutine and returns immediately
|
|
||||||
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -139,7 +170,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
|||||||
|
|
||||||
sshServer, err := s.createSSHServer(ln.Addr())
|
sshServer, err := s.createSSHServer(ln.Addr())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
s.cleanupOnError(ln)
|
s.closeListener(ln)
|
||||||
return fmt.Errorf("create SSH server: %w", err)
|
return fmt.Errorf("create SSH server: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -154,7 +185,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// createListener creates a network listener based on netstack vs standard networking
|
|
||||||
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
|
||||||
if s.netstackNet != nil {
|
if s.netstackNet != nil {
|
||||||
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
|
||||||
@@ -173,22 +203,15 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
|
|||||||
return ln, addr.String(), nil
|
return ln, addr.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// closeListener safely closes a listener
|
|
||||||
func (s *Server) closeListener(ln net.Listener) {
|
func (s *Server) closeListener(ln net.Listener) {
|
||||||
|
if ln == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
if err := ln.Close(); err != nil {
|
if err := ln.Close(); err != nil {
|
||||||
log.Debugf("listener close error: %v", err)
|
log.Debugf("listener close error: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupOnError cleans up resources when SSH server creation fails
|
|
||||||
func (s *Server) cleanupOnError(ln net.Listener) {
|
|
||||||
if s.ifIdx == 0 || ln == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
s.closeListener(ln)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop closes the SSH server
|
// Stop closes the SSH server
|
||||||
func (s *Server) Stop() error {
|
func (s *Server) Stop() error {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@@ -207,28 +230,6 @@ func (s *Server) Stop() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemoveAuthorizedKey removes the SSH key for a peer
|
|
||||||
func (s *Server) RemoveAuthorizedKey(peer string) {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
delete(s.authorizedKeys, peer)
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddAuthorizedKey adds an SSH key for a peer
|
|
||||||
func (s *Server) AddAuthorizedKey(peer, newKey string) error {
|
|
||||||
s.mu.Lock()
|
|
||||||
defer s.mu.Unlock()
|
|
||||||
|
|
||||||
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
s.authorizedKeys[peer] = parsedKey
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetstackNet sets the netstack network for userspace networking
|
// SetNetstackNet sets the netstack network for userspace networking
|
||||||
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
func (s *Server) SetNetstackNet(net *netstack.Net) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
@@ -243,34 +244,195 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
|
|||||||
s.wgAddress = addr
|
s.wgAddress = addr
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetSocketFilter configures eBPF socket filtering for the SSH server
|
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
|
||||||
func (s *Server) SetSocketFilter(ifIdx int) {
|
func (s *Server) ensureJWTValidator() error {
|
||||||
|
s.mu.RLock()
|
||||||
|
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||||
|
s.mu.RUnlock()
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
config := s.jwtConfig
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if config == nil {
|
||||||
|
return fmt.Errorf("JWT config not set")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
|
||||||
|
|
||||||
|
validator := jwt.NewValidator(
|
||||||
|
config.Issuer,
|
||||||
|
[]string{config.Audience},
|
||||||
|
config.KeysLocation,
|
||||||
|
true,
|
||||||
|
)
|
||||||
|
|
||||||
|
extractor := jwt.NewClaimsExtractor(
|
||||||
|
jwt.WithAudience(config.Audience),
|
||||||
|
)
|
||||||
|
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
s.ifIdx = ifIdx
|
|
||||||
|
if s.jwtValidator != nil && s.jwtExtractor != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
s.jwtValidator = validator
|
||||||
|
s.jwtExtractor = extractor
|
||||||
|
|
||||||
|
log.Infof("JWT validator initialized successfully")
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
|
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
defer s.mu.RUnlock()
|
jwtValidator := s.jwtValidator
|
||||||
|
jwtConfig := s.jwtConfig
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
for _, allowed := range s.authorizedKeys {
|
if jwtValidator == nil {
|
||||||
if ssh.KeysEqual(allowed, key) {
|
return nil, fmt.Errorf("JWT validator not initialized")
|
||||||
if ctx != nil {
|
}
|
||||||
log.Debugf("SSH key authentication successful for user %s from %s", ctx.User(), ctx.RemoteAddr())
|
|
||||||
|
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
|
||||||
|
if err != nil {
|
||||||
|
if jwtConfig != nil {
|
||||||
|
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
|
||||||
|
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
|
||||||
|
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
}
|
}
|
||||||
|
return nil, fmt.Errorf("validate token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if ctx != nil {
|
if err := s.checkTokenAge(token, jwtConfig); err != nil {
|
||||||
log.Warnf("SSH key authentication failed for user %s from %s: key not authorized (type: %s, fingerprint: %s)",
|
return nil, err
|
||||||
ctx.User(), ctx.RemoteAddr(), key.Type(), cryptossh.FingerprintSHA256(key))
|
|
||||||
}
|
}
|
||||||
return false
|
|
||||||
|
return token, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
|
||||||
|
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
userID := extractUserID(token)
|
||||||
|
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
iat, ok := claims["iat"].(float64)
|
||||||
|
if !ok {
|
||||||
|
userID := extractUserID(token)
|
||||||
|
return fmt.Errorf("token missing iat claim (user=%s)", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
issuedAt := time.Unix(int64(iat), 0)
|
||||||
|
tokenAge := time.Since(issuedAt)
|
||||||
|
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
|
||||||
|
if tokenAge > maxAge {
|
||||||
|
userID := getUserIDFromClaims(claims)
|
||||||
|
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
|
||||||
|
s.mu.RLock()
|
||||||
|
jwtExtractor := s.jwtExtractor
|
||||||
|
s.mu.RUnlock()
|
||||||
|
|
||||||
|
if jwtExtractor == nil {
|
||||||
|
userID := extractUserID(token)
|
||||||
|
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := jwtExtractor.ToUserAuth(token)
|
||||||
|
if err != nil {
|
||||||
|
userID := extractUserID(token)
|
||||||
|
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !s.hasSSHAccess(&userAuth) {
|
||||||
|
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &userAuth, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
|
||||||
|
return userAuth.UserId != ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func extractUserID(token *gojwt.Token) string {
|
||||||
|
if token == nil {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
claims, ok := token.Claims.(gojwt.MapClaims)
|
||||||
|
if !ok {
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
return getUserIDFromClaims(claims)
|
||||||
|
}
|
||||||
|
|
||||||
|
func getUserIDFromClaims(claims gojwt.MapClaims) string {
|
||||||
|
if sub, ok := claims["sub"].(string); ok && sub != "" {
|
||||||
|
return sub
|
||||||
|
}
|
||||||
|
if userID, ok := claims["user_id"].(string); ok && userID != "" {
|
||||||
|
return userID
|
||||||
|
}
|
||||||
|
if email, ok := claims["email"].(string); ok && email != "" {
|
||||||
|
return email
|
||||||
|
}
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
|
||||||
|
parts := strings.Split(tokenString, ".")
|
||||||
|
if len(parts) != 3 {
|
||||||
|
return nil, fmt.Errorf("invalid token format")
|
||||||
|
}
|
||||||
|
|
||||||
|
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode payload: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var claims map[string]interface{}
|
||||||
|
if err := json.Unmarshal(payload, &claims); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse claims: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return claims, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
|
||||||
|
if err := s.ensureJWTValidator(); err != nil {
|
||||||
|
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
token, err := s.validateJWTToken(password)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
userAuth, err := s.extractAndValidateUser(token)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// markConnectionActivePortForward marks an SSH connection as having an active port forward
|
|
||||||
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -286,14 +448,12 @@ func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// connectionCloseHandler cleans up connection state when SSH connections fail/close
|
|
||||||
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
|
||||||
// We can't extract the SSH connection from net.Conn directly
|
// We can't extract the SSH connection from net.Conn directly
|
||||||
// Connection cleanup will happen during session cleanup or via timeout
|
// Connection cleanup will happen during session cleanup or via timeout
|
||||||
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// findSessionKeyByContext finds the session key by matching SSH connection context
|
|
||||||
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
||||||
if ctx == nil {
|
if ctx == nil {
|
||||||
return "unknown"
|
return "unknown"
|
||||||
@@ -319,14 +479,13 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
|
|||||||
// Return a temporary key that we'll fix up later
|
// Return a temporary key that we'll fix up later
|
||||||
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
if ctx.User() != "" && ctx.RemoteAddr() != nil {
|
||||||
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
|
||||||
log.Debugf("using temporary session key for port forward tracking: %s", tempKey)
|
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
|
||||||
return tempKey
|
return tempKey
|
||||||
}
|
}
|
||||||
|
|
||||||
return "unknown"
|
return "unknown"
|
||||||
}
|
}
|
||||||
|
|
||||||
// connectionValidator validates incoming connections based on source IP
|
|
||||||
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
||||||
s.mu.RLock()
|
s.mu.RLock()
|
||||||
netbirdNetwork := s.wgAddress.Network
|
netbirdNetwork := s.wgAddress.Network
|
||||||
@@ -340,8 +499,8 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
|||||||
remoteAddr := conn.RemoteAddr()
|
remoteAddr := conn.RemoteAddr()
|
||||||
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
|
||||||
if !ok {
|
if !ok {
|
||||||
log.Debugf("SSH connection from non-TCP address %s allowed", remoteAddr)
|
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
|
||||||
return conn
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
|
||||||
@@ -357,15 +516,14 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if !netbirdNetwork.Contains(remoteIP) {
|
if !netbirdNetwork.Contains(remoteIP) {
|
||||||
log.Warnf("SSH connection rejected from non-NetBird IP %s (allowed range: %s)", remoteIP, netbirdNetwork)
|
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("SSH connection from %s allowed", remoteIP)
|
log.Infof("SSH connection from NetBird peer %s allowed", remoteIP)
|
||||||
return conn
|
return conn
|
||||||
}
|
}
|
||||||
|
|
||||||
// isShutdownError checks if the error is expected during normal shutdown
|
|
||||||
func isShutdownError(err error) bool {
|
func isShutdownError(err error) bool {
|
||||||
if errors.Is(err, net.ErrClosed) {
|
if errors.Is(err, net.ErrClosed) {
|
||||||
return true
|
return true
|
||||||
@@ -379,12 +537,16 @@ func isShutdownError(err error) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// createSSHServer creates and configures the SSH server
|
|
||||||
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
||||||
if err := enableUserSwitching(); err != nil {
|
if err := enableUserSwitching(); err != nil {
|
||||||
log.Warnf("failed to enable user switching: %v", err)
|
log.Warnf("failed to enable user switching: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
|
||||||
|
if s.jwtEnabled {
|
||||||
|
serverVersion += " " + detection.JWTRequiredMarker
|
||||||
|
}
|
||||||
|
|
||||||
server := &ssh.Server{
|
server := &ssh.Server{
|
||||||
Addr: addr.String(),
|
Addr: addr.String(),
|
||||||
Handler: s.sessionHandler,
|
Handler: s.sessionHandler,
|
||||||
@@ -402,6 +564,11 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
|||||||
},
|
},
|
||||||
ConnCallback: s.connectionValidator,
|
ConnCallback: s.connectionValidator,
|
||||||
ConnectionFailedCallback: s.connectionCloseHandler,
|
ConnectionFailedCallback: s.connectionCloseHandler,
|
||||||
|
Version: serverVersion,
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.jwtEnabled {
|
||||||
|
server.PasswordHandler = s.passwordHandler
|
||||||
}
|
}
|
||||||
|
|
||||||
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
|
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
|
||||||
@@ -413,14 +580,12 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
|
|||||||
return server, nil
|
return server, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// storeRemoteForwardListener stores a remote forward listener for cleanup
|
|
||||||
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
s.remoteForwardListeners[key] = ln
|
s.remoteForwardListeners[key] = ln
|
||||||
}
|
}
|
||||||
|
|
||||||
// removeRemoteForwardListener removes and closes a remote forward listener
|
|
||||||
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
||||||
s.mu.Lock()
|
s.mu.Lock()
|
||||||
defer s.mu.Unlock()
|
defer s.mu.Unlock()
|
||||||
@@ -438,7 +603,6 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// directTCPIPHandler handles direct-tcpip channel requests for local port forwarding with privilege validation
|
|
||||||
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
|
||||||
var payload struct {
|
var payload struct {
|
||||||
Host string
|
Host string
|
||||||
|
|||||||
@@ -22,12 +22,6 @@ func TestServer_RootLoginRestriction(t *testing.T) {
|
|||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Generate client key pair
|
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
allowRoot bool
|
allowRoot bool
|
||||||
@@ -117,10 +111,12 @@ func TestServer_RootLoginRestriction(t *testing.T) {
|
|||||||
defer cleanup()
|
defer cleanup()
|
||||||
|
|
||||||
// Create server with specific configuration
|
// Create server with specific configuration
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
server.SetAllowRootLogin(tt.allowRoot)
|
server.SetAllowRootLogin(tt.allowRoot)
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Test the userNameLookup method directly
|
// Test the userNameLookup method directly
|
||||||
user, err := server.userNameLookup(tt.username)
|
user, err := server.userNameLookup(tt.username)
|
||||||
@@ -196,7 +192,11 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
// Create server with specific configuration
|
// Create server with specific configuration
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
|
||||||
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
|
||||||
|
|
||||||
@@ -234,17 +234,13 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
|||||||
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Generate client key pair
|
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create server
|
// Create server
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
HostKeyPEM: hostKey,
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
serverAddr := StartTestServer(t, server)
|
serverAddr := StartTestServer(t, server)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -263,7 +259,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
|||||||
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel1()
|
defer cancel1()
|
||||||
|
|
||||||
client1, err := sshclient.DialInsecure(ctx1, serverAddr, currentUser.Username)
|
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := client1.Close()
|
err := client1.Close()
|
||||||
@@ -274,7 +272,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
|
|||||||
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
defer cancel2()
|
defer cancel2()
|
||||||
|
|
||||||
client2, err := sshclient.DialInsecure(ctx2, serverAddr, currentUser.Username)
|
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
|
||||||
|
InsecureSkipVerify: true,
|
||||||
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
defer func() {
|
defer func() {
|
||||||
err := client2.Close()
|
err := client2.Close()
|
||||||
|
|||||||
@@ -7,11 +7,9 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"os/user"
|
"os/user"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/gliderlabs/ssh"
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
cryptossh "golang.org/x/crypto/ssh"
|
cryptossh "golang.org/x/crypto/ssh"
|
||||||
@@ -19,82 +17,15 @@ import (
|
|||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestServer_AddAuthorizedKey(t *testing.T) {
|
|
||||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
server := New(key)
|
|
||||||
|
|
||||||
keys := map[string][]byte{}
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
|
||||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
keys[peer] = remotePubKey
|
|
||||||
}
|
|
||||||
|
|
||||||
for peer, remotePubKey := range keys {
|
|
||||||
k, ok := server.authorizedKeys[peer]
|
|
||||||
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
|
|
||||||
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_RemoveAuthorizedKey(t *testing.T) {
|
|
||||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
server := New(key)
|
|
||||||
|
|
||||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
server.RemoveAuthorizedKey("remotePeer")
|
|
||||||
|
|
||||||
_, ok := server.authorizedKeys["remotePeer"]
|
|
||||||
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_PubKeyHandler(t *testing.T) {
|
|
||||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
server := New(key)
|
|
||||||
|
|
||||||
var keys []ssh.PublicKey
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
|
|
||||||
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
err = server.AddAuthorizedKey(peer, string(remotePubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
keys = append(keys, remoteParsedPubKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, key := range keys {
|
|
||||||
accepted := server.publicKeyHandler(nil, key)
|
|
||||||
assert.True(t, accepted, "SSH key should be accepted")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestServer_StartStop(t *testing.T) {
|
func TestServer_StartStop(t *testing.T) {
|
||||||
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server := New(key)
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: key,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
|
||||||
err = server.Stop()
|
err = server.Stop()
|
||||||
assert.NoError(t, err)
|
assert.NoError(t, err)
|
||||||
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
|
|||||||
// Generate client key pair
|
// Generate client key pair
|
||||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create server with random port
|
// Create server with random port
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
// Add client's public key as authorized
|
JWT: nil,
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
}
|
||||||
require.NoError(t, err)
|
server := New(serverConfig)
|
||||||
|
|
||||||
// Start server in background
|
// Start server in background
|
||||||
serverAddr := "127.0.0.1:0"
|
serverAddr := "127.0.0.1:0"
|
||||||
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
|
|||||||
// Generate client key pair
|
// Generate client key pair
|
||||||
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create server
|
// Create server
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
HostKeyPEM: hostKey,
|
||||||
require.NoError(t, err)
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
serverAddr := "127.0.0.1:0"
|
serverAddr := "127.0.0.1:0"
|
||||||
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
|||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Generate authorized key
|
// Create server
|
||||||
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
serverConfig := &Config{
|
||||||
require.NoError(t, err)
|
HostKeyPEM: hostKey,
|
||||||
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey)
|
JWT: nil,
|
||||||
require.NoError(t, err)
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
// Generate unauthorized key (different from authorized)
|
|
||||||
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create server with only one authorized key
|
|
||||||
server := New(hostKey)
|
|
||||||
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
serverAddr := "127.0.0.1:0"
|
serverAddr := "127.0.0.1:0"
|
||||||
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Parse unauthorized private key
|
// Generate a client private key for SSH protocol (server doesn't check it)
|
||||||
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey)
|
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
|
require.NoError(t, err)
|
||||||
|
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Parse server host key
|
// Parse server host key
|
||||||
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
|
|||||||
currentUser, err := user.Current()
|
currentUser, err := user.Current()
|
||||||
require.NoError(t, err, "Should be able to get current user for test")
|
require.NoError(t, err, "Should be able to get current user for test")
|
||||||
|
|
||||||
// Try to connect with unauthorized key
|
// Try to connect with client key
|
||||||
config := &cryptossh.ClientConfig{
|
config := &cryptossh.ClientConfig{
|
||||||
User: currentUser.Username,
|
User: currentUser.Username,
|
||||||
Auth: []cryptossh.AuthMethod{
|
Auth: []cryptossh.AuthMethod{
|
||||||
cryptossh.PublicKeys(unauthorizedSigner),
|
cryptossh.PublicKeys(clientSigner),
|
||||||
},
|
},
|
||||||
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
|
||||||
Timeout: 3 * time.Second,
|
Timeout: 3 * time.Second,
|
||||||
}
|
}
|
||||||
|
|
||||||
// This should succeed in no-auth mode
|
// This should succeed in no-auth mode (server doesn't verify keys)
|
||||||
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
conn, err := cryptossh.Dial("tcp", serverAddr, config)
|
||||||
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
assert.NoError(t, err, "Connection should succeed in no-auth mode")
|
||||||
if conn != nil {
|
if conn != nil {
|
||||||
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
|
|||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
serverAddr := "127.0.0.1:0"
|
serverAddr := "127.0.0.1:0"
|
||||||
|
|
||||||
// Test multiple start/stop cycles
|
// Test multiple start/stop cycles
|
||||||
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
|
|||||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
server1 := New(hostKey)
|
serverConfig1 := &Config{
|
||||||
server2 := New(hostKey)
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server1 := New(serverConfig1)
|
||||||
|
|
||||||
|
serverConfig2 := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server2 := New(serverConfig2)
|
||||||
|
|
||||||
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
|
||||||
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")
|
||||||
|
|||||||
@@ -35,17 +35,15 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
|
|||||||
// Generate client key pair
|
// Generate client key pair
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create server with SFTP enabled
|
// Create server with SFTP enabled
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
server.SetAllowSFTP(true)
|
server.SetAllowSFTP(true)
|
||||||
server.SetAllowRootLogin(true) // Allow root login for testing
|
server.SetAllowRootLogin(true)
|
||||||
|
|
||||||
// Add client's public key as authorized
|
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
serverAddr := "127.0.0.1:0"
|
serverAddr := "127.0.0.1:0"
|
||||||
@@ -144,17 +142,15 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
|
|||||||
// Generate client key pair
|
// Generate client key pair
|
||||||
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Create server with SFTP disabled
|
// Create server with SFTP disabled
|
||||||
server := New(hostKey)
|
serverConfig := &Config{
|
||||||
|
HostKeyPEM: hostKey,
|
||||||
|
JWT: nil,
|
||||||
|
}
|
||||||
|
server := New(serverConfig)
|
||||||
server.SetAllowSFTP(false)
|
server.SetAllowSFTP(false)
|
||||||
|
|
||||||
// Add client's public key as authorized
|
|
||||||
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
// Start server
|
// Start server
|
||||||
serverAddr := "127.0.0.1:0"
|
serverAddr := "127.0.0.1:0"
|
||||||
started := make(chan string, 1)
|
started := make(chan string, 1)
|
||||||
|
|||||||
@@ -14,7 +14,6 @@ func StartTestServer(t *testing.T, server *Server) string {
|
|||||||
errChan := make(chan error, 1)
|
errChan := make(chan error, 1)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// Get a free port
|
|
||||||
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
ln, err := net.Listen("tcp", "127.0.0.1:0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errChan <- err
|
errChan <- err
|
||||||
@@ -26,9 +25,12 @@ func StartTestServer(t *testing.T, server *Server) string {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
started <- actualAddr
|
|
||||||
addrPort := netip.MustParseAddrPort(actualAddr)
|
addrPort := netip.MustParseAddrPort(actualAddr)
|
||||||
errChan <- server.Start(context.Background(), addrPort)
|
if err := server.Start(context.Background(), addrPort); err != nil {
|
||||||
|
errChan <- err
|
||||||
|
return
|
||||||
|
}
|
||||||
|
started <- actualAddr
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
select {
|
||||||
|
|||||||
172
client/ssh/testutil/user_helpers.go
Normal file
172
client/ssh/testutil/user_helpers.go
Normal file
@@ -0,0 +1,172 @@
|
|||||||
|
package testutil
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/exec"
|
||||||
|
"os/user"
|
||||||
|
"runtime"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
var testCreatedUsers = make(map[string]bool)
|
||||||
|
var testUsersToCleanup []string
|
||||||
|
|
||||||
|
// GetTestUsername returns an appropriate username for testing
|
||||||
|
func GetTestUsername(t *testing.T) string {
|
||||||
|
if runtime.GOOS == "windows" {
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err, "Should be able to get current user")
|
||||||
|
|
||||||
|
if IsSystemAccount(currentUser.Username) {
|
||||||
|
if IsCI() {
|
||||||
|
if testUser := GetOrCreateTestUser(t); testUser != "" {
|
||||||
|
return testUser
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
if _, err := user.Lookup("Administrator"); err == nil {
|
||||||
|
return "Administrator"
|
||||||
|
}
|
||||||
|
if testUser := GetOrCreateTestUser(t); testUser != "" {
|
||||||
|
return testUser
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return currentUser.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
currentUser, err := user.Current()
|
||||||
|
require.NoError(t, err, "Should be able to get current user")
|
||||||
|
return currentUser.Username
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsCI checks if we're running in a CI environment
|
||||||
|
func IsCI() bool {
|
||||||
|
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
hostname, err := os.Hostname()
|
||||||
|
if err == nil && strings.HasPrefix(hostname, "runner") {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSystemAccount checks if the user is a system account that can't authenticate
|
||||||
|
func IsSystemAccount(username string) bool {
|
||||||
|
systemAccounts := []string{
|
||||||
|
"system",
|
||||||
|
"NT AUTHORITY\\SYSTEM",
|
||||||
|
"NT AUTHORITY\\LOCAL SERVICE",
|
||||||
|
"NT AUTHORITY\\NETWORK SERVICE",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sysAccount := range systemAccounts {
|
||||||
|
if strings.EqualFold(username, sysAccount) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// RegisterTestUserCleanup registers a test user for cleanup
|
||||||
|
func RegisterTestUserCleanup(username string) {
|
||||||
|
if !testCreatedUsers[username] {
|
||||||
|
testCreatedUsers[username] = true
|
||||||
|
testUsersToCleanup = append(testUsersToCleanup, username)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CleanupTestUsers removes all created test users
|
||||||
|
func CleanupTestUsers() {
|
||||||
|
for _, username := range testUsersToCleanup {
|
||||||
|
RemoveWindowsTestUser(username)
|
||||||
|
}
|
||||||
|
testUsersToCleanup = nil
|
||||||
|
testCreatedUsers = make(map[string]bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetOrCreateTestUser creates a test user on Windows if needed
|
||||||
|
func GetOrCreateTestUser(t *testing.T) string {
|
||||||
|
testUsername := "netbird-test-user"
|
||||||
|
|
||||||
|
if _, err := user.Lookup(testUsername); err == nil {
|
||||||
|
return testUsername
|
||||||
|
}
|
||||||
|
|
||||||
|
if CreateWindowsTestUser(t, testUsername) {
|
||||||
|
RegisterTestUserCleanup(testUsername)
|
||||||
|
return testUsername
|
||||||
|
}
|
||||||
|
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveWindowsTestUser removes a local user on Windows using PowerShell
|
||||||
|
func RemoveWindowsTestUser(username string) {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
psCmd := fmt.Sprintf(`
|
||||||
|
try {
|
||||||
|
Remove-LocalUser -Name "%s" -ErrorAction Stop
|
||||||
|
Write-Output "User removed successfully"
|
||||||
|
} catch {
|
||||||
|
if ($_.Exception.Message -like "*cannot be found*") {
|
||||||
|
Write-Output "User not found (already removed)"
|
||||||
|
} else {
|
||||||
|
Write-Error $_.Exception.Message
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`, username)
|
||||||
|
|
||||||
|
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
|
||||||
|
} else {
|
||||||
|
log.Printf("Test user %s cleanup result: %s", username, string(output))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// CreateWindowsTestUser creates a local user on Windows using PowerShell
|
||||||
|
func CreateWindowsTestUser(t *testing.T, username string) bool {
|
||||||
|
if runtime.GOOS != "windows" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
psCmd := fmt.Sprintf(`
|
||||||
|
try {
|
||||||
|
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
|
||||||
|
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
|
||||||
|
Add-LocalGroupMember -Group "Users" -Member "%s"
|
||||||
|
Write-Output "User created successfully"
|
||||||
|
} catch {
|
||||||
|
if ($_.Exception.Message -like "*already exists*") {
|
||||||
|
Write-Output "User already exists"
|
||||||
|
} else {
|
||||||
|
Write-Error $_.Exception.Message
|
||||||
|
exit 1
|
||||||
|
}
|
||||||
|
}
|
||||||
|
`, username, username)
|
||||||
|
|
||||||
|
cmd := exec.Command("powershell", "-Command", psCmd)
|
||||||
|
output, err := cmd.CombinedOutput()
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("Test user creation result: %s", string(output))
|
||||||
|
return true
|
||||||
|
}
|
||||||
@@ -77,6 +77,7 @@ type Info struct {
|
|||||||
EnableSSHSFTP bool
|
EnableSSHSFTP bool
|
||||||
EnableSSHLocalPortForwarding bool
|
EnableSSHLocalPortForwarding bool
|
||||||
EnableSSHRemotePortForwarding bool
|
EnableSSHRemotePortForwarding bool
|
||||||
|
DisableSSHAuth bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *Info) SetFlags(
|
func (i *Info) SetFlags(
|
||||||
@@ -85,6 +86,7 @@ func (i *Info) SetFlags(
|
|||||||
disableClientRoutes, disableServerRoutes,
|
disableClientRoutes, disableServerRoutes,
|
||||||
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
|
||||||
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
|
||||||
|
disableSSHAuth *bool,
|
||||||
) {
|
) {
|
||||||
i.RosenpassEnabled = rosenpassEnabled
|
i.RosenpassEnabled = rosenpassEnabled
|
||||||
i.RosenpassPermissive = rosenpassPermissive
|
i.RosenpassPermissive = rosenpassPermissive
|
||||||
@@ -113,6 +115,9 @@ func (i *Info) SetFlags(
|
|||||||
if enableSSHRemotePortForwarding != nil {
|
if enableSSHRemotePortForwarding != nil {
|
||||||
i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding
|
i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
if disableSSHAuth != nil {
|
||||||
|
i.DisableSSHAuth = *disableSSHAuth
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context
|
||||||
|
|||||||
@@ -270,6 +270,7 @@ type serviceClient struct {
|
|||||||
sEnableSSHSFTP *widget.Check
|
sEnableSSHSFTP *widget.Check
|
||||||
sEnableSSHLocalPortForward *widget.Check
|
sEnableSSHLocalPortForward *widget.Check
|
||||||
sEnableSSHRemotePortForward *widget.Check
|
sEnableSSHRemotePortForward *widget.Check
|
||||||
|
sDisableSSHAuth *widget.Check
|
||||||
|
|
||||||
// observable settings over corresponding iMngURL and iPreSharedKey values.
|
// observable settings over corresponding iMngURL and iPreSharedKey values.
|
||||||
managementURL string
|
managementURL string
|
||||||
@@ -288,6 +289,7 @@ type serviceClient struct {
|
|||||||
enableSSHSFTP bool
|
enableSSHSFTP bool
|
||||||
enableSSHLocalPortForward bool
|
enableSSHLocalPortForward bool
|
||||||
enableSSHRemotePortForward bool
|
enableSSHRemotePortForward bool
|
||||||
|
disableSSHAuth bool
|
||||||
|
|
||||||
connected bool
|
connected bool
|
||||||
update *version.Update
|
update *version.Update
|
||||||
@@ -437,6 +439,7 @@ func (s *serviceClient) showSettingsUI() {
|
|||||||
s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil)
|
s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil)
|
||||||
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
|
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
|
||||||
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
|
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
|
||||||
|
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
|
||||||
|
|
||||||
s.wSettings.SetContent(s.getSettingsForm())
|
s.wSettings.SetContent(s.getSettingsForm())
|
||||||
s.wSettings.Resize(fyne.NewSize(600, 400))
|
s.wSettings.Resize(fyne.NewSize(600, 400))
|
||||||
@@ -597,6 +600,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
|
|||||||
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
|
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
|
||||||
req.EnableSSHLocalPortForward = &s.sEnableSSHLocalPortForward.Checked
|
req.EnableSSHLocalPortForward = &s.sEnableSSHLocalPortForward.Checked
|
||||||
req.EnableSSHRemotePortForward = &s.sEnableSSHRemotePortForward.Checked
|
req.EnableSSHRemotePortForward = &s.sEnableSSHRemotePortForward.Checked
|
||||||
|
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
|
||||||
|
|
||||||
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
if s.iPreSharedKey.Text != censoredPreSharedKey {
|
||||||
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
|
||||||
@@ -682,6 +686,7 @@ func (s *serviceClient) getSSHForm() *widget.Form {
|
|||||||
{Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP},
|
{Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP},
|
||||||
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
|
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
|
||||||
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
|
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
|
||||||
|
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -690,7 +695,8 @@ func (s *serviceClient) hasSSHChanges() bool {
|
|||||||
return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
|
return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
|
||||||
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
|
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
|
||||||
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
|
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
|
||||||
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked
|
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
|
||||||
|
s.disableSSHAuth != s.sDisableSSHAuth.Checked
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
|
||||||
@@ -1233,6 +1239,9 @@ func (s *serviceClient) getSrvConfig() {
|
|||||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
if cfg.EnableSSHRemotePortForwarding != nil {
|
||||||
s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding
|
s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
if cfg.DisableSSHAuth != nil {
|
||||||
|
s.disableSSHAuth = *cfg.DisableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
if s.showAdvancedSettings {
|
if s.showAdvancedSettings {
|
||||||
s.iMngURL.SetText(s.managementURL)
|
s.iMngURL.SetText(s.managementURL)
|
||||||
@@ -1266,6 +1275,9 @@ func (s *serviceClient) getSrvConfig() {
|
|||||||
if cfg.EnableSSHRemotePortForwarding != nil {
|
if cfg.EnableSSHRemotePortForwarding != nil {
|
||||||
s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding)
|
s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding)
|
||||||
}
|
}
|
||||||
|
if cfg.DisableSSHAuth != nil {
|
||||||
|
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.mNotifications == nil {
|
if s.mNotifications == nil {
|
||||||
@@ -1348,6 +1360,9 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
|
|||||||
if cfg.EnableSSHRemotePortForwarding {
|
if cfg.EnableSSHRemotePortForwarding {
|
||||||
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
|
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
|
||||||
}
|
}
|
||||||
|
if cfg.DisableSSHAuth {
|
||||||
|
config.DisableSSHAuth = &cfg.DisableSSHAuth
|
||||||
|
}
|
||||||
|
|
||||||
return &config
|
return &config
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -245,6 +245,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
h.client.getSrvConfig()
|
h.client.getSrvConfig()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
netbird "github.com/netbirdio/netbird/client/embed"
|
netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||||
@@ -125,10 +126,15 @@ func createSSHMethod(client *netbird.Client) js.Func {
|
|||||||
username = args[2].String()
|
username = args[2].String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var jwtToken string
|
||||||
|
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
|
||||||
|
jwtToken = args[3].String()
|
||||||
|
}
|
||||||
|
|
||||||
return createPromise(func(resolve, reject js.Value) {
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
sshClient := ssh.NewClient(client)
|
sshClient := ssh.NewClient(client)
|
||||||
|
|
||||||
if err := sshClient.Connect(host, port, username); err != nil {
|
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
|
||||||
reject.Invoke(err.Error())
|
reject.Invoke(err.Error())
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -191,12 +197,43 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
|||||||
}))
|
}))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// createDetectSSHServerMethod creates the SSH server detection method
|
||||||
|
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
|
||||||
|
return js.FuncOf(func(this js.Value, args []js.Value) any {
|
||||||
|
if len(args) < 2 {
|
||||||
|
return js.ValueOf("error: requires host and port")
|
||||||
|
}
|
||||||
|
|
||||||
|
host := args[0].String()
|
||||||
|
port := args[1].Int()
|
||||||
|
|
||||||
|
return createPromise(func(resolve, reject js.Value) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
serverType, err := detectSSHServerType(ctx, client, host, port)
|
||||||
|
if err != nil {
|
||||||
|
reject.Invoke(err.Error())
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resolve.Invoke(js.ValueOf(serverType.RequiresJWT()))
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// detectSSHServerType detects SSH server type using NetBird network connection
|
||||||
|
func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) {
|
||||||
|
return sshdetection.DetectSSHServerType(ctx, client, host, port)
|
||||||
|
}
|
||||||
|
|
||||||
// createClientObject wraps the NetBird client in a JavaScript object
|
// createClientObject wraps the NetBird client in a JavaScript object
|
||||||
func createClientObject(client *netbird.Client) js.Value {
|
func createClientObject(client *netbird.Client) js.Value {
|
||||||
obj := make(map[string]interface{})
|
obj := make(map[string]interface{})
|
||||||
|
|
||||||
obj["start"] = createStartMethod(client)
|
obj["start"] = createStartMethod(client)
|
||||||
obj["stop"] = createStopMethod(client)
|
obj["stop"] = createStopMethod(client)
|
||||||
|
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
|
||||||
obj["createSSHConnection"] = createSSHMethod(client)
|
obj["createSSHConnection"] = createSSHMethod(client)
|
||||||
obj["proxyRequest"] = createProxyRequestMethod(client)
|
obj["proxyRequest"] = createProxyRequestMethod(client)
|
||||||
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
obj["createRDPProxy"] = createRDPProxyMethod(client)
|
||||||
|
|||||||
@@ -13,6 +13,7 @@ import (
|
|||||||
"golang.org/x/crypto/ssh"
|
"golang.org/x/crypto/ssh"
|
||||||
|
|
||||||
netbird "github.com/netbirdio/netbird/client/embed"
|
netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect establishes an SSH connection through NetBird network
|
// Connect establishes an SSH connection through NetBird network
|
||||||
func (c *Client) Connect(host string, port int, username string) error {
|
func (c *Client) Connect(host string, port int, username, jwtToken string) error {
|
||||||
addr := fmt.Sprintf("%s:%d", host, port)
|
addr := fmt.Sprintf("%s:%d", host, port)
|
||||||
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
|
||||||
|
|
||||||
var authMethods []ssh.AuthMethod
|
authMethods, err := c.getAuthMethods(jwtToken)
|
||||||
|
|
||||||
nbConfig, err := c.nbClient.GetConfig()
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("get NetBird config: %w", err)
|
return err
|
||||||
}
|
}
|
||||||
if nbConfig.SSHKey == "" {
|
|
||||||
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("parse NetBird SSH private key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pubKey := signer.PublicKey()
|
|
||||||
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
|
|
||||||
|
|
||||||
authMethods = append(authMethods, ssh.PublicKeys(signer))
|
|
||||||
|
|
||||||
config := &ssh.ClientConfig{
|
config := &ssh.ClientConfig{
|
||||||
User: username,
|
User: username,
|
||||||
Auth: authMethods,
|
Auth: authMethods,
|
||||||
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
|
HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient),
|
||||||
Timeout: sshDialTimeout,
|
Timeout: sshDialTimeout,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// getAuthMethods returns SSH authentication methods, preferring JWT if available
|
||||||
|
func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) {
|
||||||
|
if jwtToken != "" {
|
||||||
|
logrus.Debugf("SSH: Using JWT password authentication")
|
||||||
|
return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Debugf("SSH: No JWT token, using public key authentication")
|
||||||
|
|
||||||
|
nbConfig, err := c.nbClient.GetConfig()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get NetBird config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if nbConfig.SSHKey == "" {
|
||||||
|
return nil, fmt.Errorf("no NetBird SSH key available")
|
||||||
|
}
|
||||||
|
|
||||||
|
signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey))
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse NetBird SSH private key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
logrus.Debugf("SSH: Added public key auth")
|
||||||
|
return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// StartSession starts an SSH session with PTY
|
// StartSession starts an SSH session with PTY
|
||||||
func (c *Client) StartSession(cols, rows int) error {
|
func (c *Client) StartSession(cols, rows int) error {
|
||||||
if c.sshClient == nil {
|
if c.sshClient == nil {
|
||||||
|
|||||||
@@ -1,50 +0,0 @@
|
|||||||
//go:build js
|
|
||||||
|
|
||||||
package ssh
|
|
||||||
|
|
||||||
import (
|
|
||||||
"crypto/x509"
|
|
||||||
"encoding/pem"
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
|
||||||
"golang.org/x/crypto/ssh"
|
|
||||||
)
|
|
||||||
|
|
||||||
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
|
|
||||||
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
|
|
||||||
keyStr := string(keyPEM)
|
|
||||||
if !strings.Contains(keyStr, "-----BEGIN") {
|
|
||||||
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
|
|
||||||
}
|
|
||||||
|
|
||||||
signer, err := ssh.ParsePrivateKey(keyPEM)
|
|
||||||
if err == nil {
|
|
||||||
return signer, nil
|
|
||||||
}
|
|
||||||
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
|
|
||||||
|
|
||||||
block, _ := pem.Decode(keyPEM)
|
|
||||||
if block == nil {
|
|
||||||
keyPreview := string(keyPEM)
|
|
||||||
if len(keyPreview) > 100 {
|
|
||||||
keyPreview = keyPreview[:100]
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
|
|
||||||
if err != nil {
|
|
||||||
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
|
|
||||||
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
|
|
||||||
return ssh.NewSignerFromKey(rsaKey)
|
|
||||||
}
|
|
||||||
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
|
|
||||||
return ssh.NewSignerFromKey(ecKey)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("parse private key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return ssh.NewSignerFromKey(key)
|
|
||||||
}
|
|
||||||
16
go.mod
16
go.mod
@@ -1,6 +1,6 @@
|
|||||||
module github.com/netbirdio/netbird
|
module github.com/netbirdio/netbird
|
||||||
|
|
||||||
go 1.23.0
|
go 1.23.1
|
||||||
|
|
||||||
require (
|
require (
|
||||||
cunicu.li/go-rosenpass v0.4.0
|
cunicu.li/go-rosenpass v0.4.0
|
||||||
@@ -17,8 +17,8 @@ require (
|
|||||||
github.com/spf13/cobra v1.7.0
|
github.com/spf13/cobra v1.7.0
|
||||||
github.com/spf13/pflag v1.0.5
|
github.com/spf13/pflag v1.0.5
|
||||||
github.com/vishvananda/netlink v1.3.0
|
github.com/vishvananda/netlink v1.3.0
|
||||||
golang.org/x/crypto v0.40.0
|
golang.org/x/crypto v0.41.0
|
||||||
golang.org/x/sys v0.34.0
|
golang.org/x/sys v0.35.0
|
||||||
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
|
||||||
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
|
||||||
golang.zx2c4.com/wireguard/windows v0.5.3
|
golang.zx2c4.com/wireguard/windows v0.5.3
|
||||||
@@ -31,6 +31,7 @@ require (
|
|||||||
fyne.io/fyne/v2 v2.5.3
|
fyne.io/fyne/v2 v2.5.3
|
||||||
fyne.io/systray v1.11.0
|
fyne.io/systray v1.11.0
|
||||||
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
|
||||||
|
github.com/awnumar/memguard v0.23.0
|
||||||
github.com/aws/aws-sdk-go-v2 v1.36.3
|
github.com/aws/aws-sdk-go-v2 v1.36.3
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.29.14
|
github.com/aws/aws-sdk-go-v2/config v1.29.14
|
||||||
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
|
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
|
||||||
@@ -103,11 +104,11 @@ require (
|
|||||||
goauthentik.io/api/v3 v3.2023051.3
|
goauthentik.io/api/v3 v3.2023051.3
|
||||||
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
|
||||||
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
|
||||||
golang.org/x/mod v0.25.0
|
golang.org/x/mod v0.26.0
|
||||||
golang.org/x/net v0.42.0
|
golang.org/x/net v0.42.0
|
||||||
golang.org/x/oauth2 v0.28.0
|
golang.org/x/oauth2 v0.28.0
|
||||||
golang.org/x/sync v0.16.0
|
golang.org/x/sync v0.16.0
|
||||||
golang.org/x/term v0.33.0
|
golang.org/x/term v0.34.0
|
||||||
google.golang.org/api v0.177.0
|
google.golang.org/api v0.177.0
|
||||||
gopkg.in/yaml.v3 v3.0.1
|
gopkg.in/yaml.v3 v3.0.1
|
||||||
gorm.io/driver/mysql v1.5.7
|
gorm.io/driver/mysql v1.5.7
|
||||||
@@ -128,6 +129,7 @@ require (
|
|||||||
github.com/Microsoft/go-winio v0.6.2 // indirect
|
github.com/Microsoft/go-winio v0.6.2 // indirect
|
||||||
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
github.com/Microsoft/hcsshim v0.12.3 // indirect
|
||||||
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
|
||||||
|
github.com/awnumar/memcall v0.4.0 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
|
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
|
||||||
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
|
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
|
||||||
@@ -246,9 +248,9 @@ require (
|
|||||||
go.uber.org/mock v0.4.0 // indirect
|
go.uber.org/mock v0.4.0 // indirect
|
||||||
go.uber.org/multierr v1.11.0 // indirect
|
go.uber.org/multierr v1.11.0 // indirect
|
||||||
golang.org/x/image v0.18.0 // indirect
|
golang.org/x/image v0.18.0 // indirect
|
||||||
golang.org/x/text v0.27.0 // indirect
|
golang.org/x/text v0.28.0 // indirect
|
||||||
golang.org/x/time v0.5.0 // indirect
|
golang.org/x/time v0.5.0 // indirect
|
||||||
golang.org/x/tools v0.34.0 // indirect
|
golang.org/x/tools v0.35.0 // indirect
|
||||||
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
|
||||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
|
||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
|
|||||||
28
go.sum
28
go.sum
@@ -74,6 +74,10 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd
|
|||||||
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
|
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
|
||||||
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
|
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
|
||||||
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
|
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
|
||||||
|
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
|
||||||
|
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
|
||||||
|
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
|
||||||
|
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
|
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
|
||||||
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
|
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
|
||||||
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
|
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
|
||||||
@@ -778,8 +782,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
|
|||||||
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
|
||||||
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
|
||||||
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
|
||||||
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM=
|
golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
|
||||||
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY=
|
golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
|
||||||
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
|
||||||
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
|
||||||
@@ -828,8 +832,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
|||||||
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
|
||||||
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
|
||||||
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w=
|
golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
|
||||||
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww=
|
golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
|
||||||
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
|
||||||
@@ -985,8 +989,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
|||||||
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||||
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA=
|
golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
|
||||||
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
|
||||||
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
|
||||||
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
|
||||||
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
|
||||||
@@ -999,8 +1003,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
|
|||||||
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
|
||||||
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
|
||||||
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
|
||||||
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg=
|
golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
|
||||||
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0=
|
golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
|
||||||
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
|
||||||
@@ -1017,8 +1021,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
|
|||||||
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
|
||||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||||
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4=
|
golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
|
||||||
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU=
|
golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
|
||||||
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
|
||||||
@@ -1083,8 +1087,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
|
|||||||
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
|
||||||
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
|
||||||
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
|
||||||
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo=
|
golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
|
||||||
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg=
|
golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
|
||||||
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"net/url"
|
||||||
"os"
|
"os"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
|
||||||
|
|
||||||
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/server/peers/ephemeral"
|
||||||
|
|
||||||
@@ -588,7 +590,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
|
|||||||
// if peer has reached this point then it has logged in
|
// if peer has reached this point then it has logged in
|
||||||
loginResp := &proto.LoginResponse{
|
loginResp := &proto.LoginResponse{
|
||||||
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
|
||||||
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings),
|
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings, s.config),
|
||||||
Checks: toProtocolChecks(ctx, postureChecks),
|
Checks: toProtocolChecks(ctx, postureChecks),
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -703,12 +705,21 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
|
|||||||
return nbConfig
|
return nbConfig
|
||||||
}
|
}
|
||||||
|
|
||||||
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig {
|
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, config *nbconfig.Config) *proto.PeerConfig {
|
||||||
netmask, _ := network.Net.Mask.Size()
|
netmask, _ := network.Net.Mask.Size()
|
||||||
fqdn := peer.FQDN(dnsName)
|
fqdn := peer.FQDN(dnsName)
|
||||||
|
|
||||||
|
sshConfig := &proto.SSHConfig{
|
||||||
|
SshEnabled: peer.SSHEnabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
if peer.SSHEnabled {
|
||||||
|
sshConfig.JwtConfig = buildJWTConfig(config)
|
||||||
|
}
|
||||||
|
|
||||||
return &proto.PeerConfig{
|
return &proto.PeerConfig{
|
||||||
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network
|
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask),
|
||||||
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled},
|
SshConfig: sshConfig,
|
||||||
Fqdn: fqdn,
|
Fqdn: fqdn,
|
||||||
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
|
||||||
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
LazyConnectionEnabled: settings.LazyConnectionEnabled,
|
||||||
@@ -717,7 +728,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
|
|||||||
|
|
||||||
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
|
||||||
response := &proto.SyncResponse{
|
response := &proto.SyncResponse{
|
||||||
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings),
|
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, config),
|
||||||
NetworkMap: &proto.NetworkMap{
|
NetworkMap: &proto.NetworkMap{
|
||||||
Serial: networkMap.Network.CurrentSerial(),
|
Serial: networkMap.Network.CurrentSerial(),
|
||||||
Routes: toProtocolRoutes(networkMap.Routes),
|
Routes: toProtocolRoutes(networkMap.Routes),
|
||||||
@@ -760,6 +771,55 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P
|
|||||||
return response
|
return response
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
|
||||||
|
func buildJWTConfig(config *nbconfig.Config) *proto.JWTConfig {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if config.HttpConfig == nil || config.HttpConfig.AuthAudience == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var tokenEndpoint string
|
||||||
|
if config.DeviceAuthorizationFlow != nil {
|
||||||
|
tokenEndpoint = config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint
|
||||||
|
}
|
||||||
|
|
||||||
|
issuer := deriveIssuerFromTokenEndpoint(tokenEndpoint)
|
||||||
|
if issuer == "" && config.HttpConfig.AuthIssuer != "" {
|
||||||
|
issuer = config.HttpConfig.AuthIssuer
|
||||||
|
}
|
||||||
|
if issuer == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
keysLocation := config.HttpConfig.AuthKeysLocation
|
||||||
|
if keysLocation == "" {
|
||||||
|
keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json"
|
||||||
|
}
|
||||||
|
|
||||||
|
return &proto.JWTConfig{
|
||||||
|
Issuer: issuer,
|
||||||
|
Audience: config.HttpConfig.AuthAudience,
|
||||||
|
KeysLocation: keysLocation,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// deriveIssuerFromTokenEndpoint extracts the issuer URL from a token endpoint
|
||||||
|
func deriveIssuerFromTokenEndpoint(tokenEndpoint string) string {
|
||||||
|
if tokenEndpoint == "" {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
u, err := url.Parse(tokenEndpoint)
|
||||||
|
if err != nil {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
return fmt.Sprintf("%s://%s/", u.Scheme, u.Host)
|
||||||
|
}
|
||||||
|
|
||||||
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
|
||||||
for _, rPeer := range peers {
|
for _, rPeer := range peers {
|
||||||
dst = append(dst, &proto.RemotePeerConfig{
|
dst = append(dst, &proto.RemotePeerConfig{
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -151,6 +151,7 @@ message Flags {
|
|||||||
bool enableSSHSFTP = 12;
|
bool enableSSHSFTP = 12;
|
||||||
bool enableSSHLocalPortForwarding = 13;
|
bool enableSSHLocalPortForwarding = 13;
|
||||||
bool enableSSHRemotePortForwarding = 14;
|
bool enableSSHRemotePortForwarding = 14;
|
||||||
|
bool disableSSHAuth = 15;
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerSystemMeta is machine meta data like OS and version.
|
// PeerSystemMeta is machine meta data like OS and version.
|
||||||
@@ -207,6 +208,8 @@ message NetbirdConfig {
|
|||||||
RelayConfig relay = 4;
|
RelayConfig relay = 4;
|
||||||
|
|
||||||
FlowConfig flow = 5;
|
FlowConfig flow = 5;
|
||||||
|
|
||||||
|
JWTConfig jwt = 6;
|
||||||
}
|
}
|
||||||
|
|
||||||
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
|
||||||
@@ -245,6 +248,14 @@ message FlowConfig {
|
|||||||
bool dnsCollection = 8;
|
bool dnsCollection = 8;
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// JWTConfig represents JWT authentication configuration
|
||||||
|
message JWTConfig {
|
||||||
|
string issuer = 1;
|
||||||
|
string audience = 2;
|
||||||
|
string keysLocation = 3;
|
||||||
|
int64 maxTokenAge = 4;
|
||||||
|
}
|
||||||
|
|
||||||
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
// ProtectedHostConfig is similar to HostConfig but has additional user and password
|
||||||
// Mostly used for TURN servers
|
// Mostly used for TURN servers
|
||||||
message ProtectedHostConfig {
|
message ProtectedHostConfig {
|
||||||
@@ -340,6 +351,8 @@ message SSHConfig {
|
|||||||
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
|
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
|
||||||
// This property should be ignore if SSHConfig comes from PeerConfig.
|
// This property should be ignore if SSHConfig comes from PeerConfig.
|
||||||
bytes sshPubKey = 2;
|
bytes sshPubKey = 2;
|
||||||
|
|
||||||
|
JWTConfig jwtConfig = 3;
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeviceAuthorizationFlowRequest empty struct for future expansion
|
// DeviceAuthorizationFlowRequest empty struct for future expansion
|
||||||
|
|||||||
Reference in New Issue
Block a user