Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/client/net"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile

View File

@@ -5,10 +5,12 @@ import (
"errors" "errors"
"flag" "flag"
"fmt" "fmt"
"net"
"os" "os"
"os/signal" "os/signal"
"os/user" "os/user"
"slices" "slices"
"strconv"
"strings" "strings"
"syscall" "syscall"
@@ -16,6 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
sshclient "github.com/netbirdio/netbird/client/ssh/client" sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server" sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
) )
@@ -29,6 +33,7 @@ const (
enableSSHSFTPFlag = "enable-ssh-sftp" enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding" enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding" enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
) )
var ( var (
@@ -41,6 +46,7 @@ var (
strictHostKeyChecking bool strictHostKeyChecking bool
knownHostsFile string knownHostsFile string
identityFile string identityFile string
skipCachedToken bool
) )
var ( var (
@@ -49,6 +55,7 @@ var (
enableSSHSFTP bool enableSSHSFTP bool
enableSSHLocalPortForward bool enableSSHLocalPortForward bool
enableSSHRemotePortForward bool enableSSHRemotePortForward bool
disableSSHAuth bool
) )
func init() { func init() {
@@ -57,6 +64,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server") upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server") upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server") upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port") sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc) sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
@@ -64,11 +72,14 @@ func init() {
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)") sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)") sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file") sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport") sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
sshCmd.AddCommand(sshSftpCmd) sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
} }
var sshCmd = &cobra.Command{ var sshCmd = &cobra.Command{
@@ -335,31 +346,51 @@ func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers
} }
// createSSHFlagSet creates and configures the flag set for SSH command parsing // createSSHFlagSet creates and configures the flag set for SSH command parsing
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string) { // sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath) defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel) defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError) fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil) fs.SetOutput(nil)
portFlag := fs.Int("p", sshserver.DefaultSSHPort, "SSH port") flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.Int("port", sshserver.DefaultSSHPort, "SSH port") fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
userFlag := fs.String("u", "", sshUsernameDesc) fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.String("user", "", sshUsernameDesc) fs.String("user", "", sshUsernameDesc)
loginFlag := fs.String("login", "", sshUsernameDesc+" (alias for --user)") fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
strictHostKeyCheckingFlag := fs.Bool("strict-host-key-checking", true, "Enable strict host key checking") fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
knownHostsFlag := fs.String("o", "", "Path to known_hosts file") fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.String("known-hosts", "", "Path to known_hosts file") fs.String("known-hosts", "", "Path to known_hosts file")
identityFlag := fs.String("i", "", "Path to SSH private key file") fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.String("identity", "", "Path to SSH private key file") fs.String("identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
configFlag := fs.String("c", defaultConfigPath, "Netbird config file location") fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.String("config", defaultConfigPath, "Netbird config file location") fs.String("config", defaultConfigPath, "Netbird config file location")
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level") fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.String("log-level", defaultLogLevel, "sets Netbird log level") fs.String("log-level", defaultLogLevel, "sets Netbird log level")
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag return fs, flags
} }
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error { func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
@@ -375,7 +406,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args) filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag := createSSHFlagSet() fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil { if err := fs.Parse(filteredArgs); err != nil {
return parseHostnameAndCommand(filteredArgs) return parseHostnameAndCommand(filteredArgs)
@@ -386,22 +417,23 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
return errors.New(hostArgumentRequired) return errors.New(hostArgumentRequired)
} }
port = *portFlag port = flags.Port
if *userFlag != "" { if flags.Username != "" {
username = *userFlag username = flags.Username
} else if *loginFlag != "" { } else if flags.Login != "" {
username = *loginFlag username = flags.Login
} }
strictHostKeyChecking = *strictHostKeyCheckingFlag strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = *knownHostsFlag knownHostsFile = flags.KnownHostsFile
identityFile = *identityFlag identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
if *configFlag != getEnvOrDefault("CONFIG", configPath) { if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = *configFlag configPath = flags.ConfigPath
} }
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) { if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = *logLevelFlag logLevel = flags.LogLevel
} }
localForwards = localForwardFlags localForwards = localForwardFlags
@@ -449,30 +481,20 @@ func parseHostnameAndCommand(args []string) error {
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error { func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port) target := fmt.Sprintf("%s:%d", addr, port)
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
var c *sshclient.Client KnownHostsFile: knownHostsFile,
var err error IdentityFile: identityFile,
DaemonAddr: daemonAddr,
if strictHostKeyChecking { SkipCachedToken: skipCachedToken,
c, err = sshclient.DialWithOptions(ctx, target, username, sshclient.DialOptions{ InsecureSkipVerify: !strictHostKeyChecking,
KnownHostsFile: knownHostsFile, })
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
})
} else {
c, err = sshclient.DialInsecure(ctx, target, username)
}
if err != nil { if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target) cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n") cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status\n") cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n") cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n") cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
if strictHostKeyChecking {
cmd.Printf(" 4. Try --strict-host-key-checking=false to bypass host key verification\n")
}
cmd.Printf("\n")
return fmt.Errorf("dial %s: %w", target, err) return fmt.Errorf("dial %s: %w", target, err)
} }
@@ -665,3 +687,65 @@ func normalizeLocalHost(host string) string {
} }
return host return host
} }
var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}
func sshProxyFn(cmd *cobra.Command, args []string) error {
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}
return nil
}
var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -360,6 +360,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed { if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
} }
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err) log.Errorf("parse interface name: %v", err)
@@ -460,6 +463,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
} }
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(interfaceNameFlag).Changed { if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil { if err := parseInterfaceName(interfaceName); err != nil {
return nil, err return nil, err
@@ -576,6 +583,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
} }
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(disableAutoConnectFlag).Changed { if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled loginRequest.DisableAutoConnect = &autoConnectDisabled
} }

View File

@@ -18,12 +18,16 @@ import (
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
) )
var ErrClientAlreadyStarted = errors.New("client already started") var (
var ErrClientNotStarted = errors.New("client not started") ErrClientAlreadyStarted = errors.New("client already started")
var ErrConfigNotInitialized = errors.New("config not initialized") ErrClientNotStarted = errors.New("client not started")
ErrEngineNotStarted = errors.New("engine not started")
ErrConfigNotInitialized = errors.New("config not initialized")
)
// Client manages a netbird embedded client instance. // Client manages a netbird embedded client instance.
type Client struct { type Client struct {
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network. // Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) { func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock() engine, err := c.getEngine()
connect := c.connect if err != nil {
if connect == nil { return nil, err
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
} }
nsnet, err := engine.GetNet() nsnet, err := engine.GetNet()
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address) return nsnet.DialContext(ctx, network, address)
} }
// DialContext dials a network address in the netbird network with context
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return c.Dial(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network. // ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled. // Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) { func (c *Client) ListenTCP(address string) (net.Listener, error) {
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
} }
} }
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) { // VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
engine, err := c.getEngine()
if err != nil {
return err
}
storedKey, found := engine.GetPeerSSHKey(peerAddress)
if !found {
return sshcommon.ErrPeerNotFound
}
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.
func (c *Client) getEngine() (*internal.Engine, error) {
c.mu.Lock() c.mu.Lock()
connect := c.connect connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock() c.mu.Unlock()
if connect == nil {
return nil, ErrClientNotStarted
}
engine := connect.Engine() engine := connect.Engine()
if engine == nil { if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started") return nil, ErrEngineNotStarted
}
return engine, nil
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
engine, err := c.getEngine()
if err != nil {
return nil, netip.Addr{}, err
} }
addr, err := engine.Address() addr, err := engine.Address()

View File

@@ -87,7 +87,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) { func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := d.squashAcceptRules(networkMap) rules := d.squashAcceptRules(networkMap)
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag // if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
// we have old version of management without rules handling, we should allow all traffic // we have old version of management without rules handling, we should allow all traffic
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty { if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
@@ -350,7 +349,7 @@ func (d *DefaultManager) getPeerRuleID(
// //
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network, // NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy. // but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap, ) []*mgmProto.FirewallRule { func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap) []*mgmProto.FirewallRule {
totalIPs := 0 totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) { for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps { for range p.AllowedIps {

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer" "github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet" "github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version" "github.com/netbirdio/netbird/version"
) )
@@ -437,6 +437,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
EnableSSHSFTP: config.EnableSSHSFTP, EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding, EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding, EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DNSRouteInterval: config.DNSRouteInterval, DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes, DisableClientRoutes: config.DisableClientRoutes,
@@ -527,6 +528,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHSFTP, config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding, config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding, config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
) )
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels) loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil { if err != nil {

View File

@@ -49,6 +49,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto" cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group" semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -117,6 +118,7 @@ type EngineConfig struct {
EnableSSHSFTP *bool EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DNSRouteInterval time.Duration DNSRouteInterval time.Duration
@@ -264,6 +266,7 @@ func NewEngine(
path = mobileDep.StateFilePath path = mobileDep.StateFilePath
} }
engine.stateManager = statemanager.New(path) engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String()) log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine return engine
@@ -676,14 +679,10 @@ func (e *Engine) removeAllPeers() error {
return nil return nil
} }
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server // removePeer closes an existing peer connection and removes a peer
func (e *Engine) removePeer(peerKey string) error { func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey) log.Debugf("removing peer from engine %s", peerKey)
if e.sshServer != nil {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.connMgr.RemovePeerConn(peerKey) e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey) err := e.statusRecorder.RemovePeer(peerKey)
@@ -859,6 +858,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHSFTP, e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding, e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding, e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
) )
if err := e.mgmClient.SyncMeta(info); err != nil { if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -920,6 +920,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHSFTP, e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding, e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding, e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
) )
err = e.mgmClient.Sync(e.ctx, info, e.handleSync) err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1074,24 +1075,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications() e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys
if e.sshServer != nil {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
if err != nil {
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
}
}
}
}
// update peer SSH host keys in status recorder for daemon API access
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers()) e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
// update SSH client known_hosts with peer host keys for OpenSSH client if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
if err := e.updateSSHKnownHosts(networkMap.GetRemotePeers()); err != nil { log.Warnf("failed to update SSH client config: %v", err)
log.Warnf("failed to update SSH known_hosts: %v", err)
} }
} }
@@ -1480,6 +1467,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.EnableSSHSFTP, e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding, e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding, e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
) )
netMap, err := e.mgmClient.GetNetworkMap(info) netMap, err := e.mgmClient.GetNetworkMap(info)

View File

@@ -5,10 +5,8 @@ import (
"errors" "errors"
"fmt" "fmt"
"net/netip" "net/netip"
"runtime"
"strings" "strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
@@ -21,9 +19,6 @@ import (
type sshServer interface { type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error Start(ctx context.Context, addr netip.AddrPort) error
Stop() error Stop() error
RemoveAuthorizedKey(peer string)
AddAuthorizedKey(peer, newKey string) error
SetSocketFilter(ifIdx int)
} }
func (e *Engine) setupSSHPortRedirection() error { func (e *Engine) setupSSHPortRedirection() error {
@@ -44,22 +39,6 @@ func (e *Engine) setupSSHPortRedirection() error {
return nil return nil
} }
func (e *Engine) setupSSHSocketFilter(server sshServer) error {
if runtime.GOOS != "linux" {
return nil
}
netInterface := e.wgInterface.ToInterface()
if netInterface == nil {
return errors.New("failed to get WireGuard network interface")
}
server.SetSocketFilter(netInterface.Index)
log.Debugf("SSH socket filter configured for interface %s (index: %d)", netInterface.Name, netInterface.Index)
return nil
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error { func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound { if e.config.BlockInbound {
log.Info("SSH server is disabled because inbound connections are blocked") log.Info("SSH server is disabled because inbound connections are blocked")
@@ -83,66 +62,76 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
return nil return nil
} }
return e.startSSHServer() if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
return e.startSSHServer(nil)
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
return e.startSSHServer(jwtConfig)
}
return errors.New("SSH server requires valid JWT configuration")
} }
// updateSSHKnownHosts updates the SSH known_hosts file with peer host keys for OpenSSH client // updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHKnownHosts(remotePeers []*mgmProto.RemotePeerConfig) error { func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
peerKeys := e.extractPeerHostKeys(remotePeers) peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerKeys) == 0 { if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping known_hosts update") log.Debug("no SSH-enabled peers found, skipping SSH config update")
return nil return nil
} }
if err := e.updateKnownHostsFile(peerKeys); err != nil { configMgr := sshconfig.New()
return err if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
return nil // Don't fail engine startup on SSH config issues
}
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
SSHConfigDir: configMgr.GetSSHConfigDir(),
SSHConfigFile: configMgr.GetSSHConfigFile(),
}); err != nil {
log.Warnf("failed to update SSH config state: %v", err)
} }
e.updateSSHClientConfig(peerKeys)
log.Debugf("updated SSH known_hosts with %d peer host keys", len(peerKeys))
return nil return nil
} }
// extractPeerHostKeys extracts SSH host keys from peer configurations // extractPeerSSHInfo extracts SSH information from peer configurations
func (e *Engine) extractPeerHostKeys(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerHostKey { func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
var peerKeys []sshconfig.PeerHostKey var peerInfo []sshconfig.PeerSSHInfo
for _, peerConfig := range remotePeers { for _, peerConfig := range remotePeers {
peerHostKey, ok := e.parsePeerHostKey(peerConfig) if peerConfig.GetSshConfig() == nil {
if ok { continue
peerKeys = append(peerKeys, peerHostKey)
} }
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
})
} }
return peerKeys return peerInfo
}
// parsePeerHostKey parses a single peer's SSH host key configuration
func (e *Engine) parsePeerHostKey(peerConfig *mgmProto.RemotePeerConfig) (sshconfig.PeerHostKey, bool) {
if peerConfig.GetSshConfig() == nil {
return sshconfig.PeerHostKey{}, false
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
return sshconfig.PeerHostKey{}, false
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyBytes)
if err != nil {
log.Warnf("failed to parse SSH public key for peer %s: %v", peerConfig.GetWgPubKey(), err)
return sshconfig.PeerHostKey{}, false
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
return sshconfig.PeerHostKey{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
HostKey: hostKey,
}, true
} }
// extractPeerIP extracts IP address from peer's allowed IPs // extractPeerIP extracts IP address from peer's allowed IPs
@@ -171,25 +160,6 @@ func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
return "" return ""
} }
// updateKnownHostsFile updates the SSH known_hosts file
func (e *Engine) updateKnownHostsFile(peerKeys []sshconfig.PeerHostKey) error {
configMgr := sshconfig.NewManager()
if err := configMgr.UpdatePeerHostKeys(peerKeys); err != nil {
return fmt.Errorf("update peer host keys: %w", err)
}
return nil
}
// updateSSHClientConfig updates SSH client configuration with peer hostnames
func (e *Engine) updateSSHClientConfig(peerKeys []sshconfig.PeerHostKey) {
configMgr := sshconfig.NewManager()
if err := configMgr.SetupSSHClientConfig(peerKeys); err != nil {
log.Warnf("failed to update SSH client config with peer hostnames: %v", err)
} else {
log.Debugf("updated SSH client config with %d peer hostnames", len(peerKeys))
}
}
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access // updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) { func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
for _, peerConfig := range remotePeers { for _, peerConfig := range remotePeers {
@@ -210,30 +180,51 @@ func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig)
log.Debugf("updated peer SSH host keys for daemon API access") log.Debugf("updated peer SSH host keys for daemon API access")
} }
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
e.syncMsgMux.Lock()
statusRecorder := e.statusRecorder
e.syncMsgMux.Unlock()
if statusRecorder == nil {
return nil, false
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
return peerState.SSHHostKey, true
}
return nil, false
}
}
return nil, false
}
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown // cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() { func (e *Engine) cleanupSSHConfig() {
configMgr := sshconfig.NewManager() configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil { if err := configMgr.RemoveSSHClientConfig(); err != nil {
log.Warnf("failed to remove SSH client config: %v", err) log.Warnf("failed to remove SSH client config: %v", err)
} else { } else {
log.Debugf("SSH client config cleanup completed") log.Debugf("SSH client config cleanup completed")
} }
if err := configMgr.RemoveKnownHostsFile(); err != nil {
log.Warnf("failed to remove SSH known_hosts: %v", err)
} else {
log.Debugf("SSH known_hosts cleanup completed")
}
} }
// startSSHServer initializes and starts the SSH server with proper configuration. // startSSHServer initializes and starts the SSH server with proper configuration.
func (e *Engine) startSSHServer() error { func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if e.wgInterface == nil { if e.wgInterface == nil {
return errors.New("wg interface not initialized") return errors.New("wg interface not initialized")
} }
server := sshserver.New(e.config.SSHKey) serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address() wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr) server.SetNetworkValidation(wgAddr)
@@ -259,15 +250,10 @@ func (e *Engine) startSSHServer() error {
log.Warnf("failed to setup SSH port redirection: %v", err) log.Warnf("failed to setup SSH port redirection: %v", err)
} }
if err := e.setupSSHSocketFilter(server); err != nil {
return fmt.Errorf("set socket filter: %w", err)
}
if err := server.Start(e.ctx, listenAddr); err != nil { if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err) return fmt.Errorf("start SSH server: %w", err)
} }
return nil return nil
} }

View File

@@ -281,7 +281,15 @@ func TestEngine_SSH(t *testing.T) {
networkMap = &mgmtProto.NetworkMap{ networkMap = &mgmtProto.NetworkMap{
Serial: 7, Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24", PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}}, SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH}, RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false, RemotePeersIsEmpty: false,
} }

View File

@@ -128,6 +128,7 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.EnableSSHSFTP, config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding, config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding, config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
) )
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels) loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err return serverKey, loginResp, err
@@ -158,6 +159,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.EnableSSHSFTP, config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding, config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding, config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
) )
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels) loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil { if err != nil {

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw" "github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay" "github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client" relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
) )
const eventQueueSize = 10 const eventQueueSize = 10

View File

@@ -54,6 +54,7 @@ type ConfigInput struct {
EnableSSHSFTP *bool EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
NATExternalIPs []string NATExternalIPs []string
CustomDNSAddress []byte CustomDNSAddress []byte
RosenpassEnabled *bool RosenpassEnabled *bool
@@ -102,6 +103,7 @@ type Config struct {
EnableSSHSFTP *bool EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableClientRoutes bool DisableClientRoutes bool
DisableServerRoutes bool DisableServerRoutes bool
@@ -423,6 +425,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true updated = true
} }
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
if *input.DisableSSHAuth {
log.Infof("disabling SSH authentication")
} else {
log.Infof("enabling SSH authentication")
}
config.DisableSSHAuth = input.DisableSSHAuth
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval { if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)", log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String()) input.DNSRouteInterval.String(), config.DNSRouteInterval.String())

View File

@@ -18,8 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/iface" "github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter" "github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util" "github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
const ( const (

View File

@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager" "github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system" "github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter" "github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
// ConnectionListener export internal Listener for mobile // ConnectionListener export internal Listener for mobile

View File

@@ -283,6 +283,7 @@ type LoginRequest struct {
EnableSSHSFTP *bool `protobuf:"varint,34,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` EnableSSHSFTP *bool `protobuf:"varint,34,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
EnableSSHLocalPortForwarding *bool `protobuf:"varint,35,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"` EnableSSHLocalPortForwarding *bool `protobuf:"varint,35,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"`
EnableSSHRemotePortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"` EnableSSHRemotePortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,37,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -570,6 +571,13 @@ func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool {
return false return false
} }
func (x *LoginRequest) GetDisableSSHAuth() bool {
if x != nil && x.DisableSSHAuth != nil {
return *x.DisableSSHAuth
}
return false
}
type LoginResponse struct { type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"` NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -1100,6 +1108,7 @@ type GetConfigResponse struct {
EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"` EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"`
EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"` EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"`
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"` EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -1302,6 +1311,13 @@ func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool {
return false return false
} }
func (x *GetConfigResponse) GetDisableSSHAuth() bool {
if x != nil {
return x.DisableSSHAuth
}
return false
}
// PeerState contains the latest state of a peer // PeerState contains the latest state of a peer
type PeerState struct { type PeerState struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
@@ -3781,6 +3797,7 @@ type SetConfigRequest struct {
EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"` EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
EnableSSHLocalPortForward *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForward,proto3,oneof" json:"enableSSHLocalPortForward,omitempty"` EnableSSHLocalPortForward *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForward,proto3,oneof" json:"enableSSHLocalPortForward,omitempty"`
EnableSSHRemotePortForward *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForward,proto3,oneof" json:"enableSSHRemotePortForward,omitempty"` EnableSSHRemotePortForward *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForward,proto3,oneof" json:"enableSSHRemotePortForward,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache sizeCache protoimpl.SizeCache
} }
@@ -4039,6 +4056,13 @@ func (x *SetConfigRequest) GetEnableSSHRemotePortForward() bool {
return false return false
} }
func (x *SetConfigRequest) GetDisableSSHAuth() bool {
if x != nil && x.DisableSSHAuth != nil {
return *x.DisableSSHAuth
}
return false
}
type SetConfigResponse struct { type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields unknownFields protoimpl.UnknownFields
@@ -4774,6 +4798,262 @@ func (x *GetPeerSSHHostKeyResponse) GetFound() bool {
return false return false
} }
// RequestJWTAuthRequest for initiating JWT authentication flow
type RequestJWTAuthRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RequestJWTAuthRequest) Reset() {
*x = RequestJWTAuthRequest{}
mi := &file_daemon_proto_msgTypes[71]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RequestJWTAuthRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RequestJWTAuthRequest) ProtoMessage() {}
func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[71]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead.
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{71}
}
// RequestJWTAuthResponse contains authentication flow information
type RequestJWTAuthResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// verification URI for user authentication
VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"`
// complete verification URI (with embedded user code)
VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"`
// user code to enter on verification URI
UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"`
// device code for polling
DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
// expiration time in seconds
ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
// if a cached token is available, it will be returned here
CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"`
// maximum age of JWT tokens in seconds (from management server)
MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RequestJWTAuthResponse) Reset() {
*x = RequestJWTAuthResponse{}
mi := &file_daemon_proto_msgTypes[72]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RequestJWTAuthResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RequestJWTAuthResponse) ProtoMessage() {}
func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[72]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead.
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{72}
}
func (x *RequestJWTAuthResponse) GetVerificationURI() string {
if x != nil {
return x.VerificationURI
}
return ""
}
func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string {
if x != nil {
return x.VerificationURIComplete
}
return ""
}
func (x *RequestJWTAuthResponse) GetUserCode() string {
if x != nil {
return x.UserCode
}
return ""
}
func (x *RequestJWTAuthResponse) GetDeviceCode() string {
if x != nil {
return x.DeviceCode
}
return ""
}
func (x *RequestJWTAuthResponse) GetExpiresIn() int64 {
if x != nil {
return x.ExpiresIn
}
return 0
}
func (x *RequestJWTAuthResponse) GetCachedToken() string {
if x != nil {
return x.CachedToken
}
return ""
}
func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 {
if x != nil {
return x.MaxTokenAge
}
return 0
}
// WaitJWTTokenRequest for waiting for authentication completion
type WaitJWTTokenRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// device code from RequestJWTAuthResponse
DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
// user code for verification
UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *WaitJWTTokenRequest) Reset() {
*x = WaitJWTTokenRequest{}
mi := &file_daemon_proto_msgTypes[73]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *WaitJWTTokenRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*WaitJWTTokenRequest) ProtoMessage() {}
func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[73]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead.
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{73}
}
func (x *WaitJWTTokenRequest) GetDeviceCode() string {
if x != nil {
return x.DeviceCode
}
return ""
}
func (x *WaitJWTTokenRequest) GetUserCode() string {
if x != nil {
return x.UserCode
}
return ""
}
// WaitJWTTokenResponse contains the JWT token after authentication
type WaitJWTTokenResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// JWT token (access token or ID token)
Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"`
// token type (e.g., "Bearer")
TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"`
// expiration time in seconds
ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *WaitJWTTokenResponse) Reset() {
*x = WaitJWTTokenResponse{}
mi := &file_daemon_proto_msgTypes[74]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *WaitJWTTokenResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*WaitJWTTokenResponse) ProtoMessage() {}
func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[74]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead.
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{74}
}
func (x *WaitJWTTokenResponse) GetToken() string {
if x != nil {
return x.Token
}
return ""
}
func (x *WaitJWTTokenResponse) GetTokenType() string {
if x != nil {
return x.TokenType
}
return ""
}
func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
if x != nil {
return x.ExpiresIn
}
return 0
}
type PortInfo_Range struct { type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"` state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"` Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -4784,7 +5064,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() { func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{} *x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[72] mi := &file_daemon_proto_msgTypes[76]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi) ms.StoreMessageInfo(mi)
} }
@@ -4796,7 +5076,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {} func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message { func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[72] mi := &file_daemon_proto_msgTypes[76]
if x != nil { if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil { if ms.LoadMessageInfo() == nil {
@@ -4831,7 +5111,7 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" + const file_daemon_proto_rawDesc = "" +
"\n" + "\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" + "\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
"\fEmptyRequest\"\x94\x11\n" + "\fEmptyRequest\"\xd4\x11\n" +
"\fLoginRequest\x12\x1a\n" + "\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" + "\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" + "\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4872,7 +5152,8 @@ const file_daemon_proto_rawDesc = "" +
"\renableSSHRoot\x18! \x01(\bH\x14R\renableSSHRoot\x88\x01\x01\x12)\n" + "\renableSSHRoot\x18! \x01(\bH\x14R\renableSSHRoot\x88\x01\x01\x12)\n" +
"\renableSSHSFTP\x18\" \x01(\bH\x15R\renableSSHSFTP\x88\x01\x01\x12G\n" + "\renableSSHSFTP\x18\" \x01(\bH\x15R\renableSSHSFTP\x88\x01\x01\x12G\n" +
"\x1cenableSSHLocalPortForwarding\x18# \x01(\bH\x16R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" + "\x1cenableSSHLocalPortForwarding\x18# \x01(\bH\x16R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" +
"\x1denableSSHRemotePortForwarding\x18$ \x01(\bH\x17R\x1denableSSHRemotePortForwarding\x88\x01\x01B\x13\n" + "\x1denableSSHRemotePortForwarding\x18$ \x01(\bH\x17R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" +
"\x0edisableSSHAuth\x18% \x01(\bH\x18R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" + "\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" + "\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" + "\x0e_wireguardPortB\x17\n" +
@@ -4896,7 +5177,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_enableSSHRootB\x10\n" + "\x0e_enableSSHRootB\x10\n" +
"\x0e_enableSSHSFTPB\x1f\n" + "\x0e_enableSSHSFTPB\x1f\n" +
"\x1d_enableSSHLocalPortForwardingB \n" + "\x1d_enableSSHLocalPortForwardingB \n" +
"\x1e_enableSSHRemotePortForwarding\"\xb5\x01\n" + "\x1e_enableSSHRemotePortForwardingB\x11\n" +
"\x0f_disableSSHAuth\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" + "\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" + "\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" + "\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
@@ -4929,7 +5211,7 @@ const file_daemon_proto_rawDesc = "" +
"\fDownResponse\"P\n" + "\fDownResponse\"P\n" +
"\x10GetConfigRequest\x12 \n" + "\x10GetConfigRequest\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" + "\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
"\busername\x18\x02 \x01(\tR\busername\"\x8b\b\n" + "\busername\x18\x02 \x01(\tR\busername\"\xb3\b\n" +
"\x11GetConfigResponse\x12$\n" + "\x11GetConfigResponse\x12$\n" +
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" + "\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
"\n" + "\n" +
@@ -4958,7 +5240,8 @@ const file_daemon_proto_rawDesc = "" +
"\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" + "\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" +
"\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" + "\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" +
"\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" + "\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\"\xfe\x05\n" + "\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\"\xfe\x05\n" +
"\tPeerState\x12\x0e\n" + "\tPeerState\x12\x0e\n" +
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" + "\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" + "\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
@@ -5161,7 +5444,7 @@ const file_daemon_proto_rawDesc = "" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" + "\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" + "\f_profileNameB\v\n" +
"\t_username\"\x17\n" + "\t_username\"\x17\n" +
"\x15SwitchProfileResponse\"\xcd\x0f\n" + "\x15SwitchProfileResponse\"\x8d\x10\n" +
"\x10SetConfigRequest\x12\x1a\n" + "\x10SetConfigRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" + "\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
@@ -5198,7 +5481,8 @@ const file_daemon_proto_rawDesc = "" +
"\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" + "\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" +
"\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12A\n" + "\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12A\n" +
"\x19enableSSHLocalPortForward\x18\x1f \x01(\bH\x14R\x19enableSSHLocalPortForward\x88\x01\x01\x12C\n" + "\x19enableSSHLocalPortForward\x18\x1f \x01(\bH\x14R\x19enableSSHLocalPortForward\x88\x01\x01\x12C\n" +
"\x1aenableSSHRemotePortForward\x18 \x01(\bH\x15R\x1aenableSSHRemotePortForward\x88\x01\x01B\x13\n" + "\x1aenableSSHRemotePortForward\x18 \x01(\bH\x15R\x1aenableSSHRemotePortForward\x88\x01\x01\x12+\n" +
"\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" + "\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" + "\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" + "\x0e_wireguardPortB\x17\n" +
@@ -5220,7 +5504,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_enableSSHRootB\x10\n" + "\x0e_enableSSHRootB\x10\n" +
"\x0e_enableSSHSFTPB\x1c\n" + "\x0e_enableSSHSFTPB\x1c\n" +
"\x1a_enableSSHLocalPortForwardB\x1d\n" + "\x1a_enableSSHLocalPortForwardB\x1d\n" +
"\x1b_enableSSHRemotePortForward\"\x13\n" + "\x1b_enableSSHRemotePortForwardB\x11\n" +
"\x0f_disableSSHAuth\"\x13\n" +
"\x11SetConfigResponse\"Q\n" + "\x11SetConfigResponse\"Q\n" +
"\x11AddProfileRequest\x12\x1a\n" + "\x11AddProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" + "\busername\x18\x01 \x01(\tR\busername\x12 \n" +
@@ -5259,7 +5544,27 @@ const file_daemon_proto_rawDesc = "" +
"sshHostKey\x12\x16\n" + "sshHostKey\x12\x16\n" +
"\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" + "\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" +
"\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" + "\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" +
"\x05found\x18\x04 \x01(\bR\x05found*b\n" + "\x05found\x18\x04 \x01(\bR\x05found\"\x17\n" +
"\x15RequestJWTAuthRequest\"\x9a\x02\n" +
"\x16RequestJWTAuthResponse\x12(\n" +
"\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" +
"\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" +
"\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" +
"\n" +
"deviceCode\x18\x04 \x01(\tR\n" +
"deviceCode\x12\x1c\n" +
"\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" +
"\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" +
"\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" +
"\x13WaitJWTTokenRequest\x12\x1e\n" +
"\n" +
"deviceCode\x18\x01 \x01(\tR\n" +
"deviceCode\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" +
"\x14WaitJWTTokenResponse\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
"\bLogLevel\x12\v\n" + "\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" + "\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" + "\x05PANIC\x10\x01\x12\t\n" +
@@ -5268,7 +5573,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" + "\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" + "\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" + "\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xeb\x10\n" + "\x05TRACE\x10\a2\x8b\x12\n" +
"\rDaemonService\x126\n" + "\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" + "\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" + "\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -5301,7 +5606,9 @@ const file_daemon_proto_rawDesc = "" +
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" + "\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" + "\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" + "\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00B\bZ\x06/protob\x06proto3" "\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3"
var ( var (
file_daemon_proto_rawDescOnce sync.Once file_daemon_proto_rawDescOnce sync.Once
@@ -5316,7 +5623,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
} }
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3) var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 74) var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 78)
var file_daemon_proto_goTypes = []any{ var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel (LogLevel)(0), // 0: daemon.LogLevel
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity (SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
@@ -5392,18 +5699,22 @@ var file_daemon_proto_goTypes = []any{
(*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse (*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyRequest)(nil), // 72: daemon.GetPeerSSHHostKeyRequest (*GetPeerSSHHostKeyRequest)(nil), // 72: daemon.GetPeerSSHHostKeyRequest
(*GetPeerSSHHostKeyResponse)(nil), // 73: daemon.GetPeerSSHHostKeyResponse (*GetPeerSSHHostKeyResponse)(nil), // 73: daemon.GetPeerSSHHostKeyResponse
nil, // 74: daemon.Network.ResolvedIPsEntry (*RequestJWTAuthRequest)(nil), // 74: daemon.RequestJWTAuthRequest
(*PortInfo_Range)(nil), // 75: daemon.PortInfo.Range (*RequestJWTAuthResponse)(nil), // 75: daemon.RequestJWTAuthResponse
nil, // 76: daemon.SystemEvent.MetadataEntry (*WaitJWTTokenRequest)(nil), // 76: daemon.WaitJWTTokenRequest
(*durationpb.Duration)(nil), // 77: google.protobuf.Duration (*WaitJWTTokenResponse)(nil), // 77: daemon.WaitJWTTokenResponse
(*timestamppb.Timestamp)(nil), // 78: google.protobuf.Timestamp nil, // 78: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 79: daemon.PortInfo.Range
nil, // 80: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 81: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 82: google.protobuf.Timestamp
} }
var file_daemon_proto_depIdxs = []int32{ var file_daemon_proto_depIdxs = []int32{
77, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 81, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus 22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
78, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp 82, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
78, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp 82, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
77, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration 81, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState 19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState 18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState 17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -5412,8 +5723,8 @@ var file_daemon_proto_depIdxs = []int32{
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState 21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent 52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network 28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
74, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry 78, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
75, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range 79, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo 29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo 29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule 30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -5424,10 +5735,10 @@ var file_daemon_proto_depIdxs = []int32{
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage 49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity 1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category 2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
78, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp 82, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
76, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry 80, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent 52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
77, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration 81, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile 65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList 27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest 4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -5459,37 +5770,41 @@ var file_daemon_proto_depIdxs = []int32{
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest 68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest 70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
72, // 59: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest 72, // 59: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
5, // 60: daemon.DaemonService.Login:output_type -> daemon.LoginResponse 74, // 60: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
7, // 61: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse 76, // 61: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
9, // 62: daemon.DaemonService.Up:output_type -> daemon.UpResponse 5, // 62: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
11, // 63: daemon.DaemonService.Status:output_type -> daemon.StatusResponse 7, // 63: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
13, // 64: daemon.DaemonService.Down:output_type -> daemon.DownResponse 9, // 64: daemon.DaemonService.Up:output_type -> daemon.UpResponse
15, // 65: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse 11, // 65: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
24, // 66: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse 13, // 66: daemon.DaemonService.Down:output_type -> daemon.DownResponse
26, // 67: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse 15, // 67: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
26, // 68: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse 24, // 68: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
31, // 69: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse 26, // 69: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
33, // 70: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse 26, // 70: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
35, // 71: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse 31, // 71: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
37, // 72: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse 33, // 72: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
40, // 73: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse 35, // 73: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
42, // 74: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse 37, // 74: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
44, // 75: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse 40, // 75: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
46, // 76: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse 42, // 76: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
50, // 77: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse 44, // 77: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
52, // 78: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent 46, // 78: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
54, // 79: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse 50, // 79: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
56, // 80: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse 52, // 80: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
58, // 81: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse 54, // 81: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
60, // 82: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse 56, // 82: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
62, // 83: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse 58, // 83: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
64, // 84: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse 60, // 84: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
67, // 85: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse 62, // 85: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
69, // 86: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse 64, // 86: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
71, // 87: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse 67, // 87: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
73, // 88: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse 69, // 88: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
60, // [60:89] is the sub-list for method output_type 71, // 89: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
31, // [31:60] is the sub-list for method input_type 73, // 90: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
75, // 91: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
77, // 92: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
62, // [62:93] is the sub-list for method output_type
31, // [31:62] is the sub-list for method input_type
31, // [31:31] is the sub-list for extension type_name 31, // [31:31] is the sub-list for extension type_name
31, // [31:31] is the sub-list for extension extendee 31, // [31:31] is the sub-list for extension extendee
0, // [0:31] is the sub-list for field type_name 0, // [0:31] is the sub-list for field type_name
@@ -5518,7 +5833,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(), GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)), RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 3, NumEnums: 3,
NumMessages: 74, NumMessages: 78,
NumExtensions: 0, NumExtensions: 0,
NumServices: 1, NumServices: 1,
}, },

View File

@@ -87,6 +87,12 @@ service DaemonService {
// GetPeerSSHHostKey retrieves SSH host key for a specific peer // GetPeerSSHHostKey retrieves SSH host key for a specific peer
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {} rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
// RequestJWTAuth initiates JWT authentication flow for SSH
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
} }
@@ -166,6 +172,7 @@ message LoginRequest {
optional bool enableSSHSFTP = 34; optional bool enableSSHSFTP = 34;
optional bool enableSSHLocalPortForwarding = 35; optional bool enableSSHLocalPortForwarding = 35;
optional bool enableSSHRemotePortForwarding = 36; optional bool enableSSHRemotePortForwarding = 36;
optional bool disableSSHAuth = 37;
} }
message LoginResponse { message LoginResponse {
@@ -268,6 +275,8 @@ message GetConfigResponse {
bool enableSSHLocalPortForwarding = 22; bool enableSSHLocalPortForwarding = 22;
bool enableSSHRemotePortForwarding = 23; bool enableSSHRemotePortForwarding = 23;
bool disableSSHAuth = 25;
} }
// PeerState contains the latest state of a peer // PeerState contains the latest state of a peer
@@ -612,6 +621,7 @@ message SetConfigRequest {
optional bool enableSSHSFTP = 30; optional bool enableSSHSFTP = 30;
optional bool enableSSHLocalPortForward = 31; optional bool enableSSHLocalPortForward = 31;
optional bool enableSSHRemotePortForward = 32; optional bool enableSSHRemotePortForward = 32;
optional bool disableSSHAuth = 33;
} }
message SetConfigResponse{} message SetConfigResponse{}
@@ -681,3 +691,43 @@ message GetPeerSSHHostKeyResponse {
// indicates if the SSH host key was found // indicates if the SSH host key was found
bool found = 4; bool found = 4;
} }
// RequestJWTAuthRequest for initiating JWT authentication flow
message RequestJWTAuthRequest {
}
// RequestJWTAuthResponse contains authentication flow information
message RequestJWTAuthResponse {
// verification URI for user authentication
string verificationURI = 1;
// complete verification URI (with embedded user code)
string verificationURIComplete = 2;
// user code to enter on verification URI
string userCode = 3;
// device code for polling
string deviceCode = 4;
// expiration time in seconds
int64 expiresIn = 5;
// if a cached token is available, it will be returned here
string cachedToken = 6;
// maximum age of JWT tokens in seconds (from management server)
int64 maxTokenAge = 7;
}
// WaitJWTTokenRequest for waiting for authentication completion
message WaitJWTTokenRequest {
// device code from RequestJWTAuthResponse
string deviceCode = 1;
// user code for verification
string userCode = 2;
}
// WaitJWTTokenResponse contains the JWT token after authentication
message WaitJWTTokenResponse {
// JWT token (access token or ID token)
string token = 1;
// token type (e.g., "Bearer")
string tokenType = 2;
// expiration time in seconds
int64 expiresIn = 3;
}

View File

@@ -66,6 +66,10 @@ type DaemonServiceClient interface {
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error) GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer // GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error) GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
} }
type daemonServiceClient struct { type daemonServiceClient struct {
@@ -360,6 +364,24 @@ func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeer
return out, nil return out, nil
} }
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
out := new(RequestJWTAuthResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
out := new(WaitJWTTokenResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service. // DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer // All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility // for forward compatibility
@@ -412,6 +434,10 @@ type DaemonServiceServer interface {
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error) GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer // GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
mustEmbedUnimplementedDaemonServiceServer() mustEmbedUnimplementedDaemonServiceServer()
} }
@@ -506,6 +532,12 @@ func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeature
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) { func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented") return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
} }
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
}
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {} func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service. // UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1044,6 +1076,42 @@ func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Conte
return interceptor(ctx, in, info, handler) return interceptor(ctx, in, info, handler)
} }
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RequestJWTAuthRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(WaitJWTTokenRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/WaitJWTToken",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service. // DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService, // It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy) // and not to be introspected or modified (even as a copy)
@@ -1163,6 +1231,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetPeerSSHHostKey", MethodName: "GetPeerSSHHostKey",
Handler: _DaemonService_GetPeerSSHHostKey_Handler, Handler: _DaemonService_GetPeerSSHHostKey_Handler,
}, },
{
MethodName: "RequestJWTAuth",
Handler: _DaemonService_RequestJWTAuth_Handler,
},
{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
}, },
Streams: []grpc.StreamDesc{ Streams: []grpc.StreamDesc{
{ {

View File

@@ -0,0 +1,73 @@
package server
import (
"sync"
"time"
"github.com/awnumar/memguard"
log "github.com/sirupsen/logrus"
)
type jwtCache struct {
mu sync.RWMutex
enclave *memguard.Enclave
expiresAt time.Time
timer *time.Timer
maxTokenSize int
}
func newJWTCache() *jwtCache {
return &jwtCache{
maxTokenSize: 8192,
}
}
func (c *jwtCache) store(token string, maxAge time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanup()
if c.timer != nil {
c.timer.Stop()
}
tokenBytes := []byte(token)
c.enclave = memguard.NewEnclave(tokenBytes)
c.expiresAt = time.Now().Add(maxAge)
c.timer = time.AfterFunc(maxAge, func() {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanup()
c.timer = nil
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
})
}
func (c *jwtCache) get() (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.enclave == nil || time.Now().After(c.expiresAt) {
return "", false
}
buffer, err := c.enclave.Open()
if err != nil {
log.Debugf("Failed to open JWT token enclave: %v", err)
return "", false
}
defer buffer.Destroy()
token := string(buffer.Bytes())
return token, true
}
// cleanup destroys the secure enclave, must be called with lock held
func (c *jwtCache) cleanup() {
if c.enclave != nil {
c.enclave = nil
}
}

View File

@@ -11,8 +11,8 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
) )
type selectRoute struct { type selectRoute struct {

View File

@@ -46,6 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7 defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon
defaultJWTCacheTTL = 5 * time.Minute
errRestoreResidualState = "failed to restore residual state: %v" errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled" errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled" errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
@@ -81,6 +84,8 @@ type Server struct {
profileManager *profilemanager.ServiceManager profileManager *profilemanager.ServiceManager
profilesDisabled bool profilesDisabled bool
updateSettingsDisabled bool updateSettingsDisabled bool
jwtCache *jwtCache
} }
type oauthAuthFlow struct { type oauthAuthFlow struct {
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile), profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled, profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled, updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
} }
} }
@@ -370,6 +376,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.EnableSSHSFTP = msg.EnableSSHSFTP config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.Mtu != nil { if msg.Mtu != nil {
mtu := uint16(*msg.Mtu) mtu := uint16(*msg.Mtu)
@@ -486,7 +495,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, err return nil, err
} }
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) { if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) { if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous oauth flow info") log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{ return &proto.LoginResponse{
@@ -503,7 +512,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
} }
} }
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO()) authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil { if err != nil {
log.Errorf("getting a request OAuth flow failed: %v", err) log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err return nil, err
@@ -1077,28 +1086,41 @@ func (s *Server) GetPeerSSHHostKey(
} }
s.mutex.Lock() s.mutex.Lock()
defer s.mutex.Unlock() connectClient := s.connectClient
statusRecorder := s.statusRecorder
s.mutex.Unlock()
response := &proto.GetPeerSSHHostKeyResponse{ if connectClient == nil {
Found: false, return nil, errors.New("client not initialized")
} }
if s.statusRecorder == nil { engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
peerAddress := req.GetPeerAddress()
hostKey, found := engine.GetPeerSSHKey(peerAddress)
response := &proto.GetPeerSSHHostKeyResponse{
Found: found,
}
if !found {
return response, nil return response, nil
} }
fullStatus := s.statusRecorder.GetFullStatus() response.SshHostKey = hostKey
peerAddress := req.GetPeerAddress()
// Search for peer by IP or FQDN if statusRecorder == nil {
return response, nil
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers { for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress { if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 { response.PeerIP = peerState.IP
response.SshHostKey = peerState.SSHHostKey response.PeerFQDN = peerState.FQDN
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
response.Found = true
}
break break
} }
} }
@@ -1106,6 +1128,137 @@ func (s *Server) GetPeerSSHHostKey(
return response, nil return response, nil
} }
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
// NB_SSH_JWT_CACHE_TTL=0 disables caching
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
func getJWTCacheTTL() time.Duration {
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
if envValue == "" {
return defaultJWTCacheTTL
}
seconds, err := strconv.Atoi(envValue)
if err != nil {
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
return defaultJWTCacheTTL
}
if seconds == 0 {
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
return 0
}
ttl := time.Duration(seconds) * time.Second
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
return ttl
}
// RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth(
ctx context.Context,
_ *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
}
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
return &proto.RequestJWTAuthResponse{
CachedToken: cachedToken,
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
}
isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
}
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
}
s.mutex.Lock()
s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock()
return &proto.RequestJWTAuthResponse{
VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: authInfo.UserCode,
DeviceCode: authInfo.DeviceCode,
ExpiresIn: int64(authInfo.ExpiresIn),
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
// WaitJWTToken waits for JWT authentication completion
func (s *Server) WaitJWTToken(
ctx context.Context,
req *proto.WaitJWTTokenRequest,
) (*proto.WaitJWTTokenResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
oAuthFlow := s.oauthAuthFlow.flow
authInfo := s.oauthAuthFlow.info
s.mutex.Unlock()
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
}
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
}
token := tokenInfo.GetTokenToUse()
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
} else {
log.Debug("JWT caching disabled, not storing token")
}
s.mutex.Lock()
s.oauthAuthFlow = oauthAuthFlow{}
s.mutex.Unlock()
return &proto.WaitJWTTokenResponse{
Token: tokenInfo.GetTokenToUse(),
TokenType: tokenInfo.TokenType,
ExpiresIn: int64(tokenInfo.ExpiresIn),
}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func (s *Server) runProbes() { func (s *Server) runProbes() {
if s.connectClient == nil { if s.connectClient == nil {
return return
@@ -1191,13 +1344,18 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
} }
disableSSHAuth := false
if s.config.DisableSSHAuth != nil {
disableSSHAuth = *s.config.DisableSSHAuth
}
return &proto.GetConfigResponse{ return &proto.GetConfigResponse{
ManagementUrl: managementURL.String(), ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey, PreSharedKey: preSharedKey,
AdminURL: adminURL.String(), AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface, InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort), WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU), Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect, DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed, ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled, RosenpassEnabled: cfg.RosenpassEnabled,
@@ -1214,6 +1372,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHSFTP: enableSSHSFTP, EnableSSHSFTP: enableSSHSFTP,
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding, EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding, EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
}, nil }, nil
} }

View File

@@ -6,9 +6,11 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
) )
func registerStates(mgr *statemanager.Manager) { func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{}) mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
} }

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns" "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops" "github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager" "github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
) )
func registerStates(mgr *statemanager.Manager) { func registerStates(mgr *statemanager.Manager) {
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&systemops.ShutdownState{}) mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&nftables.ShutdownState{}) mgr.RegisterState(&nftables.ShutdownState{})
mgr.RegisterState(&iptables.ShutdownState{}) mgr.RegisterState(&iptables.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
} }

View File

@@ -21,6 +21,8 @@ import (
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto" "github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
) )
const ( const (
@@ -40,12 +42,10 @@ type Client struct {
windowsStdinMode uint32 // nolint:unused windowsStdinMode uint32 // nolint:unused
} }
// Close terminates the SSH connection
func (c *Client) Close() error { func (c *Client) Close() error {
return c.client.Close() return c.client.Close()
} }
// OpenTerminal opens an interactive terminal session
func (c *Client) OpenTerminal(ctx context.Context) error { func (c *Client) OpenTerminal(ctx context.Context) error {
session, err := c.client.NewSession() session, err := c.client.NewSession()
if err != nil { if err != nil {
@@ -259,43 +259,29 @@ func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error
return session, cleanup, nil return session, cleanup, nil
} }
// Dial connects to the given ssh server with proper host key verification // getDefaultDaemonAddr returns the daemon address from environment or default for the OS
func Dial(ctx context.Context, addr, user string) (*Client, error) { func getDefaultDaemonAddr() string {
hostKeyCallback, err := createHostKeyCallback(addr) if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
if err != nil { return addr
return nil, fmt.Errorf("create host key callback: %w", err)
} }
if runtime.GOOS == "windows" {
config := &ssh.ClientConfig{ return DefaultDaemonAddrWindows
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: hostKeyCallback,
} }
return DefaultDaemonAddr
return dial(ctx, "tcp", addr, config)
}
// DialInsecure connects to the given ssh server without host key verification (for testing only)
func DialInsecure(ctx context.Context, addr, user string) (*Client, error) {
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 - Only used for tests
}
return dial(ctx, "tcp", addr, config)
} }
// DialOptions contains options for SSH connections // DialOptions contains options for SSH connections
type DialOptions struct { type DialOptions struct {
KnownHostsFile string KnownHostsFile string
IdentityFile string IdentityFile string
DaemonAddr string DaemonAddr string
SkipCachedToken bool
InsecureSkipVerify bool
} }
// DialWithOptions connects to the given ssh server with specified options // Dial connects to the given ssh server with specified options
func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) { func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
hostKeyCallback, err := createHostKeyCallbackWithOptions(addr, opts) hostKeyCallback, err := createHostKeyCallback(opts)
if err != nil { if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err) return nil, fmt.Errorf("create host key callback: %w", err)
} }
@@ -306,7 +292,6 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
HostKeyCallback: hostKeyCallback, HostKeyCallback: hostKeyCallback,
} }
// Add SSH key authentication if identity file is specified
if opts.IdentityFile != "" { if opts.IdentityFile != "" {
authMethod, err := createSSHKeyAuth(opts.IdentityFile) authMethod, err := createSSHKeyAuth(opts.IdentityFile)
if err != nil { if err != nil {
@@ -315,11 +300,16 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
config.Auth = append(config.Auth, authMethod) config.Auth = append(config.Auth, authMethod)
} }
return dial(ctx, "tcp", addr, config) daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
} }
// dial establishes an SSH connection // dialSSH establishes an SSH connection without JWT authentication
func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) { func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
dialer := &net.Dialer{} dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, network, addr) conn, err := dialer.DialContext(ctx, network, addr)
if err != nil { if err != nil {
@@ -340,143 +330,84 @@ func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (
}, nil }, nil
} }
// createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files // dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) { func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
daemonAddr := os.Getenv("NB_DAEMON_ADDR") host, portStr, err := net.SplitHostPort(addr)
if daemonAddr == "" { if err != nil {
if runtime.GOOS == "windows" { return nil, fmt.Errorf("parse address %s: %w", addr, err)
daemonAddr = DefaultDaemonAddrWindows
} else {
daemonAddr = DefaultDaemonAddr
}
} }
return createHostKeyCallbackWithDaemonAddr(addr, daemonAddr) port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
return nil, fmt.Errorf("SSH server detection failed: %w", err)
}
if !serverType.RequiresJWT() {
return dialSSH(ctx, network, addr, config)
}
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
if err != nil {
return nil, fmt.Errorf("request JWT token: %w", err)
}
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
return dialSSH(ctx, network, addr, configWithJWT)
} }
// createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address // requestJWTToken requests a JWT token from the NetBird daemon
func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCallback, error) { func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error { conn, err := connectToDaemon(daemonAddr)
// First try to get host key from NetBird daemon if err != nil {
if err := verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr); err == nil { return "", fmt.Errorf("connect to daemon: %w", err)
return nil }
} defer conn.Close()
// Fallback to known_hosts files client := proto.NewDaemonServiceClient(conn)
knownHostsFiles := getKnownHostsFiles() return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
// Try each known_hosts callback
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
}, nil
} }
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon // verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error { func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
client, err := connectToDaemon(daemonAddr) conn, err := connectToDaemon(daemonAddr)
if err != nil { if err != nil {
return err return err
} }
defer func() { defer func() {
if err := client.Close(); err != nil { if err := conn.Close(); err != nil {
log.Debugf("daemon connection close error: %v", err) log.Debugf("daemon connection close error: %v", err)
} }
}() }()
addresses := buildAddressList(hostname, remote) client := proto.NewDaemonServiceClient(conn)
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses) verifier := nbssh.NewDaemonHostKeyVerifier(client)
callback := nbssh.CreateHostKeyCallback(verifier)
return verifyKeyWithDaemon(client, addresses, key) return callback(hostname, remote, key)
} }
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) { func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second) addr := strings.TrimPrefix(daemonAddr, "tcp://")
defer cancel()
conn, err := grpc.DialContext( conn, err := grpc.NewClient(
ctx, addr,
strings.TrimPrefix(daemonAddr, "tcp://"),
grpc.WithTransportCredentials(insecure.NewCredentials()), grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
) )
if err != nil { if err != nil {
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err) log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err) return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
} }
return conn, nil return conn, nil
} }
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}
func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error {
client := proto.NewDaemonServiceClient(conn)
for _, addr := range addresses {
if err := checkAddressKey(client, addr, key); err == nil {
return nil
}
}
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
}
func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: addr,
})
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
if err != nil {
log.Debugf("daemon query error for %s: %v", addr, err)
return err
}
if !response.GetFound() {
log.Debugf("SSH host key not found in daemon for address: %s", addr)
return fmt.Errorf("key not found")
}
return compareKeys(response.GetSshHostKey(), key, addr)
}
func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error {
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
return err
}
if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) {
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
return nil
}
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type())
return fmt.Errorf("key mismatch")
}
// getKnownHostsFiles returns paths to known_hosts files in order of preference // getKnownHostsFiles returns paths to known_hosts files in order of preference
func getKnownHostsFiles() []string { func getKnownHostsFiles() []string {
var files []string var files []string
@@ -503,8 +434,12 @@ func getKnownHostsFiles() []string {
return files return files
} }
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options // createHostKeyCallback creates a host key verification callback
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) { func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
if opts.InsecureSkipVerify {
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
}
return func(hostname string, remote net.Addr, key ssh.PublicKey) error { return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil { if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
return nil return nil

View File

@@ -7,29 +7,36 @@ import (
"io" "io"
"net" "net"
"os" "os"
"os/exec"
"os/user" "os/user"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
"time" "time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh" cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh" "github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server" sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
) )
// TestMain handles package-level setup and cleanup // TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests // Run tests
code := m.Run() code := m.Run()
// Cleanup any created test users // Cleanup any created test users
cleanupTestUsers() testutil.CleanupTestUsers()
os.Exit(code) os.Exit(code)
} }
@@ -39,19 +46,14 @@ func TestSSHClient_DialWithKey(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create and start server // Create and start server
server := sshserver.New(hostKey) serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server) serverAddr := sshserver.StartTestServer(t, server)
defer func() { defer func() {
err := server.Stop() err := server.Stop()
@@ -62,8 +64,10 @@ func TestSSHClient_DialWithKey(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
currentUser := getCurrentUsername() currentUser := testutil.GetTestUsername(t)
client, err := DialInsecure(ctx, serverAddr, currentUser) client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err := client.Close() err := client.Close()
@@ -75,7 +79,7 @@ func TestSSHClient_DialWithKey(t *testing.T) {
} }
func TestSSHClient_CommandExecution(t *testing.T) { func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && isCI() { if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues") t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
} }
@@ -129,20 +133,16 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
}() }()
// Generate client key for multiple connections // Generate client key for multiple connections
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("multi-peer", string(clientPubKey))
require.NoError(t, err)
const numClients = 3 const numClients = 3
clients := make([]*Client, numClients) clients := make([]*Client, numClients)
for i := 0; i < numClients; i++ { for i := 0; i < numClients; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
currentUser := getCurrentUsername() currentUser := testutil.GetTestUsername(t)
client, err := DialInsecure(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i)) client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
InsecureSkipVerify: true,
})
cancel() cancel()
require.NoError(t, err, "Client %d should connect successfully", i) require.NoError(t, err, "Client %d should connect successfully", i)
clients[i] = client clients[i] = client
@@ -161,19 +161,14 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("cancel-peer", string(clientPubKey))
require.NoError(t, err)
t.Run("connection with short timeout", func(t *testing.T) { t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond) ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel() defer cancel()
currentUser := getCurrentUsername() currentUser := testutil.GetTestUsername(t)
_, err = DialInsecure(ctx, serverAddr, currentUser) _, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil { if err != nil {
// Check for actual timeout-related errors rather than string matching // Check for actual timeout-related errors rather than string matching
assert.True(t, assert.True(t,
@@ -187,8 +182,10 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
t.Run("command execution cancellation", func(t *testing.T) { t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
currentUser := getCurrentUsername() currentUser := testutil.GetTestUsername(t)
client, err := DialInsecure(ctx, serverAddr, currentUser) client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
if err := client.Close(); err != nil { if err := client.Close(); err != nil {
@@ -214,7 +211,11 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
server := sshserver.New(hostKey) serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server) serverAddr := sshserver.StartTestServer(t, server)
@@ -226,10 +227,12 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
currentUser := getCurrentUsername() currentUser := testutil.GetTestUsername(t)
t.Run("any key succeeds in no-auth mode", func(t *testing.T) { t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
client, err := DialInsecure(ctx, serverAddr, currentUser) client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
assert.NoError(t, err) assert.NoError(t, err)
if client != nil { if client != nil {
require.NoError(t, client.Close(), "Client should close without error") require.NoError(t, client.Close(), "Client should close without error")
@@ -282,24 +285,22 @@ func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Clie
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) serverConfig := &sshserver.Config{
require.NoError(t, err) HostKeyPEM: hostKey,
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey) JWT: nil,
require.NoError(t, err) }
server := sshserver.New(serverConfig)
server := sshserver.New(hostKey)
server.SetAllowRootLogin(true) // Allow root/admin login for tests server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server) serverAddr := sshserver.StartTestServer(t, server)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel() defer cancel()
currentUser := getCurrentUsername() currentUser := testutil.GetTestUsername(t)
client, err := DialInsecure(ctx, serverAddr, currentUser) client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err) require.NoError(t, err)
return server, serverAddr, client return server, serverAddr, client
@@ -361,18 +362,14 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) serverConfig := &sshserver.Config{
require.NoError(t, err) HostKeyPEM: hostKey,
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey) JWT: nil,
require.NoError(t, err) }
server := sshserver.New(serverConfig)
server := sshserver.New(hostKey)
server.SetAllowLocalPortForwarding(true) server.SetAllowLocalPortForwarding(true)
server.SetAllowRootLogin(true) // Allow root/admin login for tests server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server) serverAddr := sshserver.StartTestServer(t, server)
defer func() { defer func() {
err := server.Stop() err := server.Stop()
@@ -387,11 +384,13 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Skip if running as system account that can't do port forwarding // Skip if running as system account that can't do port forwarding
if isSystemAccount(realUser) { if testutil.IsSystemAccount(realUser) {
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser) t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
} }
client, err := DialInsecure(ctx, serverAddr, realUser) client, err := Dial(ctx, serverAddr, realUser, DialOptions{
InsecureSkipVerify: true, // Skip host key verification for test
})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
if err := client.Close(); err != nil { if err := client.Close(); err != nil {
@@ -478,180 +477,6 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
assert.Equal(t, expectedResponse, string(response)) assert.Equal(t, expectedResponse, string(response))
} }
// getCurrentUsername returns the current username for SSH connections
func getCurrentUsername() string {
if runtime.GOOS == "windows" {
if currentUser, err := user.Current(); err == nil {
// Check if this is a system account that can't authenticate
if isSystemAccount(currentUser.Username) {
// In CI environments, create a test user; otherwise try Administrator
if isCI() {
if testUser := getOrCreateTestUser(); testUser != "" {
return testUser
}
} else {
// Try Administrator first for local development
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := getOrCreateTestUser(); testUser != "" {
return testUser
}
}
}
// On Windows, return the full domain\username for proper authentication
return currentUser.Username
}
}
if username := os.Getenv("USER"); username != "" {
return username
}
if currentUser, err := user.Current(); err == nil {
return currentUser.Username
}
return "test-user"
}
// isCI checks if we're running in GitHub Actions CI
func isCI() bool {
// Check standard CI environment variables
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// getOrCreateTestUser creates a test user on Windows if needed
func getOrCreateTestUser() string {
testUsername := "netbird-test-user"
// Check if user already exists
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
// Try to create the user using PowerShell
if createWindowsTestUser(testUsername) {
// Register cleanup for the test user
registerTestUserCleanup(testUsername)
return testUsername
}
return ""
}
var createdTestUsers = make(map[string]bool)
var testUsersToCleanup []string
// registerTestUserCleanup registers a test user for cleanup
func registerTestUserCleanup(username string) {
if !createdTestUsers[username] {
createdTestUsers[username] = true
testUsersToCleanup = append(testUsersToCleanup, username)
}
}
// cleanupTestUsers removes all created test users
func cleanupTestUsers() {
for _, username := range testUsersToCleanup {
removeWindowsTestUser(username)
}
testUsersToCleanup = nil
createdTestUsers = make(map[string]bool)
}
// removeWindowsTestUser removes a local user on Windows using PowerShell
func removeWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
// PowerShell command to remove a local user
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// createWindowsTestUser creates a local user on Windows using PowerShell
func createWindowsTestUser(username string) bool {
if runtime.GOOS != "windows" {
return false
}
// PowerShell command to create a local user
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
log.Printf("Test user creation result: %s", string(output))
return true
}
// isSystemAccount checks if the user is a system account that can't authenticate
func isSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding // getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
func getRealCurrentUser() (string, error) { func getRealCurrentUser() (string, error) {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {

167
client/ssh/common.go Normal file
View File

@@ -0,0 +1,167 @@
package ssh
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/proto"
)
const (
NetBirdSSHConfigFile = "99-netbird.conf"
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
WindowsSSHConfigDir = "ssh/ssh_config.d"
)
var (
// ErrPeerNotFound indicates the peer was not found in the network
ErrPeerNotFound = errors.New("peer not found in network")
// ErrNoStoredKey indicates the peer has no stored SSH host key
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
)
// HostKeyVerifier provides SSH host key verification
type HostKeyVerifier interface {
VerifySSHHostKey(peerAddress string, key []byte) error
}
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
type DaemonHostKeyVerifier struct {
client proto.DaemonServiceClient
}
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
return &DaemonHostKeyVerifier{
client: client,
}
}
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: peerAddress,
})
if err != nil {
return err
}
if !response.GetFound() {
return ErrPeerNotFound
}
storedKeyData := response.GetSshHostKey()
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
}
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
if err != nil {
return "", fmt.Errorf("request JWT auth: %w", err)
}
if useCache && authResponse.CachedToken != "" {
log.Debug("Using cached authentication token")
return authResponse.CachedToken, nil
}
if stderr != nil {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
}
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
}
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
DeviceCode: authResponse.DeviceCode,
UserCode: authResponse.UserCode,
})
if err != nil {
return "", fmt.Errorf("wait for JWT token: %w", err)
}
if stdout != nil {
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
}
return tokenResponse.Token, nil
}
// VerifyHostKey verifies an SSH host key against stored peer key data.
// Returns nil only if the presented key matches the stored key.
// Returns ErrNoStoredKey if storedKeyData is empty.
// Returns an error if the keys don't match or if parsing fails.
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
if len(storedKeyData) == 0 {
return ErrNoStoredKey
}
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
}
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
}
return nil
}
// AddJWTAuth prepends JWT password authentication to existing auth methods.
// This ensures JWT auth is tried first while preserving any existing auth methods.
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
configWithJWT := *config
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
return &configWithJWT
}
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
// It tries multiple addresses (hostname, IP) for the peer before failing.
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
addresses := buildAddressList(hostname, remote)
presentedKey := key.Marshal()
for _, addr := range addresses {
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
if errors.Is(err, ErrPeerNotFound) {
// Try other addresses for this peer
continue
}
return err
}
// Verified
return nil
}
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
}
}
// buildAddressList creates a list of addresses to check for host key verification.
// It includes the original hostname and extracts the host part from the remote address if different.
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}

View File

@@ -1,7 +1,6 @@
package config package config
import ( import (
"bufio"
"context" "context"
"fmt" "fmt"
"os" "os"
@@ -12,50 +11,41 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
) )
const ( const (
// EnvDisableSSHConfig is the environment variable to disable SSH config management
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG" EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
// EnvForceSSHConfig is the environment variable to force SSH config generation even with many peers
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG" EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
// MaxPeersForSSHConfig is the default maximum number of peers before SSH config generation is disabled
MaxPeersForSSHConfig = 200 MaxPeersForSSHConfig = 200
// fileWriteTimeout is the timeout for file write operations
fileWriteTimeout = 2 * time.Second fileWriteTimeout = 2 * time.Second
) )
// isSSHConfigDisabled checks if SSH config management is disabled via environment variable
func isSSHConfigDisabled() bool { func isSSHConfigDisabled() bool {
value := os.Getenv(EnvDisableSSHConfig) value := os.Getenv(EnvDisableSSHConfig)
if value == "" { if value == "" {
return false return false
} }
// Parse as boolean, default to true if non-empty but invalid
disabled, err := strconv.ParseBool(value) disabled, err := strconv.ParseBool(value)
if err != nil { if err != nil {
// If not a valid boolean, treat any non-empty value as true
return true return true
} }
return disabled return disabled
} }
// isSSHConfigForced checks if SSH config generation is forced via environment variable
func isSSHConfigForced() bool { func isSSHConfigForced() bool {
value := os.Getenv(EnvForceSSHConfig) value := os.Getenv(EnvForceSSHConfig)
if value == "" { if value == "" {
return false return false
} }
// Parse as boolean, default to true if non-empty but invalid
forced, err := strconv.ParseBool(value) forced, err := strconv.ParseBool(value)
if err != nil { if err != nil {
// If not a valid boolean, treat any non-empty value as true
return true return true
} }
return forced return forced
@@ -92,85 +82,55 @@ func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error
} }
} }
// writeFileOperationWithTimeout performs a file operation with timeout
func writeFileOperationWithTimeout(filename string, operation func() error) error {
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
done <- operation()
}()
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
}
}
// Manager handles SSH client configuration for NetBird peers // Manager handles SSH client configuration for NetBird peers
type Manager struct { type Manager struct {
sshConfigDir string sshConfigDir string
sshConfigFile string sshConfigFile string
knownHostsDir string
knownHostsFile string
userKnownHosts string
} }
// PeerHostKey represents a peer's SSH host key information // PeerSSHInfo represents a peer's SSH configuration information
type PeerHostKey struct { type PeerSSHInfo struct {
Hostname string Hostname string
IP string IP string
FQDN string FQDN string
HostKey ssh.PublicKey
} }
// NewManager creates a new SSH config manager // New creates a new SSH config manager
func NewManager() *Manager { func New() *Manager {
sshConfigDir, knownHostsDir := getSystemSSHPaths() sshConfigDir := getSystemSSHConfigDir()
return &Manager{ return &Manager{
sshConfigDir: sshConfigDir, sshConfigDir: sshConfigDir,
sshConfigFile: "99-netbird.conf", sshConfigFile: nbssh.NetBirdSSHConfigFile,
knownHostsDir: knownHostsDir,
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
} }
} }
// getSystemSSHPaths returns platform-specific SSH configuration paths // getSystemSSHConfigDir returns platform-specific SSH configuration directory
func getSystemSSHPaths() (configDir, knownHostsDir string) { func getSystemSSHConfigDir() string {
switch runtime.GOOS { if runtime.GOOS == "windows" {
case "windows": return getWindowsSSHConfigDir()
configDir, knownHostsDir = getWindowsSSHPaths()
default:
// Unix-like systems (Linux, macOS, etc.)
configDir = "/etc/ssh/ssh_config.d"
knownHostsDir = "/etc/ssh/ssh_known_hosts.d"
} }
return configDir, knownHostsDir return nbssh.UnixSSHConfigDir
} }
func getWindowsSSHPaths() (configDir, knownHostsDir string) { func getWindowsSSHConfigDir() string {
programData := os.Getenv("PROGRAMDATA") programData := os.Getenv("PROGRAMDATA")
if programData == "" { if programData == "" {
programData = `C:\ProgramData` programData = `C:\ProgramData`
} }
configDir = filepath.Join(programData, "ssh", "ssh_config.d") return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d")
return configDir, knownHostsDir
} }
// SetupSSHClientConfig creates SSH client configuration for NetBird peers // SetupSSHClientConfig creates SSH client configuration for NetBird peers
func (m *Manager) SetupSSHClientConfig(peerKeys []PeerHostKey) error { func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
if !shouldGenerateSSHConfig(len(peerKeys)) { if !shouldGenerateSSHConfig(len(peers)) {
m.logSkipReason(len(peerKeys)) m.logSkipReason(len(peers))
return nil return nil
} }
knownHostsPath := m.getKnownHostsPath() sshConfig, err := m.buildSSHConfig(peers)
sshConfig := m.buildSSHConfig(peerKeys, knownHostsPath) if err != nil {
return fmt.Errorf("build SSH config: %w", err)
}
return m.writeSSHConfig(sshConfig) return m.writeSSHConfig(sshConfig)
} }
@@ -183,21 +143,24 @@ func (m *Manager) logSkipReason(peerCount int) {
} }
} }
func (m *Manager) getKnownHostsPath() string { func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
knownHostsPath, err := m.setupKnownHostsFile()
if err != nil {
log.Warnf("Failed to setup known_hosts file: %v", err)
return "/dev/null"
}
return knownHostsPath
}
func (m *Manager) buildSSHConfig(peerKeys []PeerHostKey, knownHostsPath string) string {
sshConfig := m.buildConfigHeader() sshConfig := m.buildConfigHeader()
for _, peer := range peerKeys {
sshConfig += m.buildPeerConfig(peer, knownHostsPath) var allHostPatterns []string
for _, peer := range peers {
hostPatterns := m.buildHostPatterns(peer)
allHostPatterns = append(allHostPatterns, hostPatterns...)
} }
return sshConfig
if len(allHostPatterns) > 0 {
peerConfig, err := m.buildPeerConfig(allHostPatterns)
if err != nil {
return "", err
}
sshConfig += peerConfig
}
return sshConfig, nil
} }
func (m *Manager) buildConfigHeader() string { func (m *Manager) buildConfigHeader() string {
@@ -209,25 +172,49 @@ func (m *Manager) buildConfigHeader() string {
"#\n\n" "#\n\n"
} }
func (m *Manager) buildPeerConfig(peer PeerHostKey, knownHostsPath string) string { func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
hostPatterns := m.buildHostPatterns(peer) uniquePatterns := make(map[string]bool)
if len(hostPatterns) == 0 { var deduplicatedPatterns []string
return "" for _, pattern := range allHostPatterns {
if !uniquePatterns[pattern] {
uniquePatterns[pattern] = true
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
}
} }
hostLine := strings.Join(hostPatterns, " ") execPath, err := m.getNetBirdExecutablePath()
if err != nil {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine) config := fmt.Sprintf("Host %s\n", hostLine)
config += " # NetBird peer-specific configuration\n"
config += " PreferredAuthentications password,publickey,keyboard-interactive\n" if runtime.GOOS == "windows" {
config += " PasswordAuthentication yes\n" config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
config += " PubkeyAuthentication yes\n" } else {
config += " BatchMode no\n" config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
config += m.buildHostKeyConfig(knownHostsPath) }
config += " LogLevel ERROR\n\n" config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
return config config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
} }
func (m *Manager) buildHostPatterns(peer PeerHostKey) []string { func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
var hostPatterns []string var hostPatterns []string
if peer.IP != "" { if peer.IP != "" {
hostPatterns = append(hostPatterns, peer.IP) hostPatterns = append(hostPatterns, peer.IP)
@@ -241,280 +228,55 @@ func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
return hostPatterns return hostPatterns
} }
func (m *Manager) buildHostKeyConfig(knownHostsPath string) string {
if knownHostsPath == "/dev/null" {
return " StrictHostKeyChecking no\n" +
" UserKnownHostsFile /dev/null\n"
}
return " StrictHostKeyChecking yes\n" +
fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath)
}
func (m *Manager) writeSSHConfig(sshConfig string) error { func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil { if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err) return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
return m.setupUserConfig(sshConfig)
} }
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil { if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
log.Warnf("Failed to write SSH config file %s: %v", sshConfigPath, err) return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
return m.setupUserConfig(sshConfig)
} }
log.Infof("Created NetBird SSH client config: %s", sshConfigPath) log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
return nil return nil
} }
// setupUserConfig creates SSH config in user's directory as fallback
func (m *Manager) setupUserConfig(sshConfig string) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("get user home directory: %w", err)
}
userSSHDir := filepath.Join(homeDir, ".ssh")
userConfigPath := filepath.Join(userSSHDir, "config")
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
return fmt.Errorf("create user SSH directory: %w", err)
}
// Check if NetBird config already exists in user config
exists, err := m.configExists(userConfigPath)
if err != nil {
return fmt.Errorf("check existing config: %w", err)
}
if exists {
log.Debugf("NetBird SSH config already exists in %s", userConfigPath)
return nil
}
// Append NetBird config to user's SSH config with timeout
if err := writeFileOperationWithTimeout(userConfigPath, func() error {
file, err := os.OpenFile(userConfigPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return fmt.Errorf("open user SSH config: %w", err)
}
defer func() {
if err := file.Close(); err != nil {
log.Debugf("user SSH config file close error: %v", err)
}
}()
if _, err := fmt.Fprintf(file, "\n%s", sshConfig); err != nil {
return fmt.Errorf("write to user SSH config: %w", err)
}
return nil
}); err != nil {
return err
}
log.Infof("Added NetBird SSH config to user config: %s", userConfigPath)
return nil
}
// configExists checks if NetBird SSH config already exists
func (m *Manager) configExists(configPath string) (bool, error) {
file, err := os.Open(configPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("open SSH config file: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.Contains(line, "NetBird SSH client configuration") {
return true, nil
}
}
return false, scanner.Err()
}
// RemoveSSHClientConfig removes NetBird SSH configuration // RemoveSSHClientConfig removes NetBird SSH configuration
func (m *Manager) RemoveSSHClientConfig() error { func (m *Manager) RemoveSSHClientConfig() error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile) sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
err := os.Remove(sshConfigPath)
// Remove system-wide config if it exists if err != nil && !os.IsNotExist(err) {
if err := os.Remove(sshConfigPath); err != nil && !os.IsNotExist(err) { return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
log.Warnf("Failed to remove system SSH config %s: %v", sshConfigPath, err) }
} else if err == nil { if err == nil {
log.Infof("Removed NetBird SSH config: %s", sshConfigPath) log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
} }
// Also try to clean up user config
homeDir, err := os.UserHomeDir()
if err != nil {
log.Debugf("failed to get user home directory: %v", err)
return nil
}
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
if err := m.removeFromUserConfig(userConfigPath); err != nil {
log.Warnf("Failed to clean user SSH config: %v", err)
}
return nil return nil
} }
// removeFromUserConfig removes NetBird section from user's SSH config func (m *Manager) getNetBirdExecutablePath() (string, error) {
func (m *Manager) removeFromUserConfig(configPath string) error { execPath, err := os.Executable()
// This is complex to implement safely, so for now just log
// In practice, the system-wide config takes precedence anyway
log.Debugf("NetBird SSH config cleanup from user config not implemented")
return nil
}
// setupKnownHostsFile creates and returns the path to NetBird known_hosts file
func (m *Manager) setupKnownHostsFile() (string, error) {
// Try system-wide known_hosts first
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
if err := os.MkdirAll(m.knownHostsDir, 0755); err == nil {
// Create empty file if it doesn't exist
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
if err := writeFileWithTimeout(knownHostsPath, []byte("# NetBird SSH known hosts\n"), 0644); err == nil {
log.Debugf("Created NetBird known_hosts file: %s", knownHostsPath)
return knownHostsPath, nil
}
} else if err == nil {
return knownHostsPath, nil
}
}
// Fallback to user directory
homeDir, err := os.UserHomeDir()
if err != nil { if err != nil {
return "", fmt.Errorf("get user home directory: %w", err) return "", fmt.Errorf("retrieve executable path: %w", err)
} }
userSSHDir := filepath.Join(homeDir, ".ssh") realPath, err := filepath.EvalSymlinks(execPath)
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
return "", fmt.Errorf("create user SSH directory: %w", err)
}
userKnownHostsPath := filepath.Join(userSSHDir, m.userKnownHosts)
if _, err := os.Stat(userKnownHostsPath); os.IsNotExist(err) {
if err := writeFileWithTimeout(userKnownHostsPath, []byte("# NetBird SSH known hosts\n"), 0600); err != nil {
return "", fmt.Errorf("create user known_hosts file: %w", err)
}
log.Debugf("Created NetBird user known_hosts file: %s", userKnownHostsPath)
}
return userKnownHostsPath, nil
}
// UpdatePeerHostKeys updates the known_hosts file with peer host keys
func (m *Manager) UpdatePeerHostKeys(peerKeys []PeerHostKey) error {
peerCount := len(peerKeys)
// Check if SSH config should be generated
if !shouldGenerateSSHConfig(peerCount) {
if isSSHConfigDisabled() {
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
} else {
log.Infof("SSH known_hosts update skipped: too many peers (%d > %d). Use %s=true to force.",
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
}
return nil
}
knownHostsPath, err := m.setupKnownHostsFile()
if err != nil { if err != nil {
return fmt.Errorf("setup known_hosts file: %w", err) log.Debugf("symlink resolution failed: %v", err)
return execPath, nil
} }
// Create updated known_hosts content - NetBird file should only contain NetBird entries return realPath, nil
var updatedContent strings.Builder
updatedContent.WriteString("# NetBird SSH known hosts\n")
updatedContent.WriteString("# Generated automatically - do not edit manually\n\n")
// Add new NetBird entries - one entry per peer with all hostnames
for _, peerKey := range peerKeys {
entry := m.formatKnownHostsEntry(peerKey)
updatedContent.WriteString(entry)
updatedContent.WriteString("\n")
}
// Write updated content
if err := writeFileWithTimeout(knownHostsPath, []byte(updatedContent.String()), 0644); err != nil {
return fmt.Errorf("write known_hosts file: %w", err)
}
log.Debugf("Updated NetBird known_hosts with %d peer keys: %s", len(peerKeys), knownHostsPath)
return nil
} }
// formatKnownHostsEntry formats a peer host key as a known_hosts entry // GetSSHConfigDir returns the SSH config directory path
func (m *Manager) formatKnownHostsEntry(peerKey PeerHostKey) string { func (m *Manager) GetSSHConfigDir() string {
hostnames := m.getHostnameVariants(peerKey) return m.sshConfigDir
hostnameList := strings.Join(hostnames, ",")
keyString := string(ssh.MarshalAuthorizedKey(peerKey.HostKey))
keyString = strings.TrimSpace(keyString)
return fmt.Sprintf("%s %s", hostnameList, keyString)
} }
// getHostnameVariants returns all possible hostname variants for a peer // GetSSHConfigFile returns the SSH config file name
func (m *Manager) getHostnameVariants(peerKey PeerHostKey) []string { func (m *Manager) GetSSHConfigFile() string {
var hostnames []string return m.sshConfigFile
// Add IP address
if peerKey.IP != "" {
hostnames = append(hostnames, peerKey.IP)
}
// Add FQDN
if peerKey.FQDN != "" {
hostnames = append(hostnames, peerKey.FQDN)
}
// Add hostname if different from FQDN
if peerKey.Hostname != "" && peerKey.Hostname != peerKey.FQDN {
hostnames = append(hostnames, peerKey.Hostname)
}
// Add bracketed IP for non-standard ports (SSH standard)
if peerKey.IP != "" {
hostnames = append(hostnames, fmt.Sprintf("[%s]:22", peerKey.IP))
hostnames = append(hostnames, fmt.Sprintf("[%s]:22022", peerKey.IP))
}
return hostnames
}
// GetKnownHostsPath returns the path to the NetBird known_hosts file
func (m *Manager) GetKnownHostsPath() (string, error) {
return m.setupKnownHostsFile()
}
// RemoveKnownHostsFile removes the NetBird known_hosts file
func (m *Manager) RemoveKnownHostsFile() error {
// Remove system-wide known_hosts if it exists
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
if err := os.Remove(knownHostsPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove system known_hosts %s: %v", knownHostsPath, err)
} else if err == nil {
log.Infof("Removed NetBird known_hosts: %s", knownHostsPath)
}
// Also try to clean up user known_hosts
homeDir, err := os.UserHomeDir()
if err != nil {
log.Debugf("failed to get user home directory: %v", err)
return nil
}
userKnownHostsPath := filepath.Join(homeDir, ".ssh", m.userKnownHosts)
if err := os.Remove(userKnownHostsPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove user known_hosts %s: %v", userKnownHostsPath, err)
} else if err == nil {
log.Infof("Removed NetBird user known_hosts: %s", userKnownHostsPath)
}
return nil
} }

View File

@@ -10,81 +10,8 @@ import (
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
) )
func TestManager_UpdatePeerHostKeys(t *testing.T) {
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
}
// Generate test host keys
hostKey1, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey1, err := ssh.ParsePrivateKey(hostKey1)
require.NoError(t, err)
hostKey2, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey2, err := ssh.ParsePrivateKey(hostKey2)
require.NoError(t, err)
// Create test peer host keys
peerKeys := []PeerHostKey{
{
Hostname: "peer1",
IP: "100.125.1.1",
FQDN: "peer1.nb.internal",
HostKey: pubKey1.PublicKey(),
},
{
Hostname: "peer2",
IP: "100.125.1.2",
FQDN: "peer2.nb.internal",
HostKey: pubKey2.PublicKey(),
},
}
// Test updating known_hosts
err = manager.UpdatePeerHostKeys(peerKeys)
require.NoError(t, err)
// Verify known_hosts file was created and contains entries
knownHostsPath, err := manager.GetKnownHostsPath()
require.NoError(t, err)
content, err := os.ReadFile(knownHostsPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, "100.125.1.1")
assert.Contains(t, contentStr, "100.125.1.2")
assert.Contains(t, contentStr, "peer1.nb.internal")
assert.Contains(t, contentStr, "peer2.nb.internal")
assert.Contains(t, contentStr, "[100.125.1.1]:22")
assert.Contains(t, contentStr, "[100.125.1.1]:22022")
// Test updating with empty list should preserve structure
err = manager.UpdatePeerHostKeys([]PeerHostKey{})
require.NoError(t, err)
content, err = os.ReadFile(knownHostsPath)
require.NoError(t, err)
assert.Contains(t, string(content), "# NetBird SSH known hosts")
}
func TestManager_SetupSSHClientConfig(t *testing.T) { func TestManager_SetupSSHClientConfig(t *testing.T) {
// Create temporary directory for test // Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test") tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
@@ -93,15 +20,25 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
// Override manager paths to use temp directory // Override manager paths to use temp directory
manager := &Manager{ manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf", sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
} }
// Test SSH config generation with empty peer keys // Test SSH config generation with peers
err = manager.SetupSSHClientConfig(nil) peers := []PeerSSHInfo{
{
Hostname: "peer1",
IP: "100.125.1.1",
FQDN: "peer1.nb.internal",
},
{
Hostname: "peer2",
IP: "100.125.1.2",
FQDN: "peer2.nb.internal",
},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err) require.NoError(t, err)
// Read generated config // Read generated config
@@ -111,134 +48,39 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
configStr := string(content) configStr := string(content)
// Since we now use per-peer configurations instead of domain patterns, // Verify the basic SSH config structure exists
// we should verify the basic SSH config structure exists
assert.Contains(t, configStr, "# NetBird SSH client configuration") assert.Contains(t, configStr, "# NetBird SSH client configuration")
assert.Contains(t, configStr, "Generated automatically - do not edit manually") assert.Contains(t, configStr, "Generated automatically - do not edit manually")
// Should not contain /dev/null since we have a proper known_hosts setup // Check that peer hostnames are included
assert.NotContains(t, configStr, "UserKnownHostsFile /dev/null") assert.Contains(t, configStr, "100.125.1.1")
} assert.Contains(t, configStr, "100.125.1.2")
assert.Contains(t, configStr, "peer1.nb.internal")
assert.Contains(t, configStr, "peer2.nb.internal")
func TestManager_GetHostnameVariants(t *testing.T) { // Check platform-specific UserKnownHostsFile
manager := NewManager()
peerKey := PeerHostKey{
Hostname: "testpeer",
IP: "100.125.1.10",
FQDN: "testpeer.nb.internal",
HostKey: nil, // Not needed for this test
}
variants := manager.getHostnameVariants(peerKey)
expectedVariants := []string{
"100.125.1.10",
"testpeer.nb.internal",
"testpeer",
"[100.125.1.10]:22",
"[100.125.1.10]:22022",
}
assert.ElementsMatch(t, expectedVariants, variants)
}
func TestManager_FormatKnownHostsEntry(t *testing.T) {
manager := NewManager()
// Generate test key
hostKeyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
parsedKey, err := ssh.ParsePrivateKey(hostKeyPEM)
require.NoError(t, err)
peerKey := PeerHostKey{
Hostname: "testpeer",
IP: "100.125.1.10",
FQDN: "testpeer.nb.internal",
HostKey: parsedKey.PublicKey(),
}
entry := manager.formatKnownHostsEntry(peerKey)
// Should contain all hostname variants
assert.Contains(t, entry, "100.125.1.10")
assert.Contains(t, entry, "testpeer.nb.internal")
assert.Contains(t, entry, "testpeer")
assert.Contains(t, entry, "[100.125.1.10]:22")
assert.Contains(t, entry, "[100.125.1.10]:22022")
// Should contain the public key
keyString := string(ssh.MarshalAuthorizedKey(parsedKey.PublicKey()))
keyString = strings.TrimSpace(keyString)
assert.Contains(t, entry, keyString)
// Should be properly formatted (hostnames followed by key)
parts := strings.Fields(entry)
assert.GreaterOrEqual(t, len(parts), 2, "Entry should have hostnames and key parts")
}
func TestManager_DirectoryFallback(t *testing.T) {
// Create temporary directory for test where system dirs will fail
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Set HOME to temp directory to control user fallback
t.Setenv("HOME", tempDir)
// Create manager with non-writable system directories
// Use paths that will fail on all systems
var failPath string
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
failPath = "NUL:" // Special device that can't be used as directory on Windows assert.Contains(t, configStr, "UserKnownHostsFile NUL")
} else { } else {
failPath = "/dev/null" // Special device that can't be used as directory on Unix assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
} }
manager := &Manager{
sshConfigDir: failPath + "/ssh_config.d", // Should fail
sshConfigFile: "99-netbird.conf",
knownHostsDir: failPath + "/ssh_known_hosts.d", // Should fail
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
}
// Should fall back to user directory
knownHostsPath, err := manager.setupKnownHostsFile()
require.NoError(t, err)
// Get the actual user home directory as determined by os.UserHomeDir()
userHome, err := os.UserHomeDir()
require.NoError(t, err)
expectedUserPath := filepath.Join(userHome, ".ssh", "known_hosts_netbird")
assert.Equal(t, expectedUserPath, knownHostsPath)
// Verify file was created
_, err = os.Stat(knownHostsPath)
require.NoError(t, err)
} }
func TestGetSystemSSHPaths(t *testing.T) { func TestGetSystemSSHConfigDir(t *testing.T) {
configDir, knownHostsDir := getSystemSSHPaths() configDir := getSystemSSHConfigDir()
// Paths should not be empty // Path should not be empty
assert.NotEmpty(t, configDir) assert.NotEmpty(t, configDir)
assert.NotEmpty(t, knownHostsDir)
// Should be absolute paths // Should be an absolute path
assert.True(t, filepath.IsAbs(configDir)) assert.True(t, filepath.IsAbs(configDir))
assert.True(t, filepath.IsAbs(knownHostsDir))
// On Unix systems, should start with /etc // On Unix systems, should start with /etc
// On Windows, should contain ProgramData // On Windows, should contain ProgramData
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
assert.Contains(t, strings.ToLower(configDir), "programdata") assert.Contains(t, strings.ToLower(configDir), "programdata")
assert.Contains(t, strings.ToLower(knownHostsDir), "programdata")
} else { } else {
assert.Contains(t, configDir, "/etc/ssh") assert.Contains(t, configDir, "/etc/ssh")
assert.Contains(t, knownHostsDir, "/etc/ssh")
} }
} }
@@ -250,46 +92,28 @@ func TestManager_PeerLimit(t *testing.T) {
// Override manager paths to use temp directory // Override manager paths to use temp directory
manager := &Manager{ manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf", sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
} }
// Generate many peer keys (more than limit) // Generate many peers (more than limit)
var peerKeys []PeerHostKey var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ { for i := 0; i < MaxPeersForSSHConfig+10; i++ {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) peers = append(peers, PeerSSHInfo{
require.NoError(t, err)
pubKey, err := ssh.ParsePrivateKey(hostKey)
require.NoError(t, err)
peerKeys = append(peerKeys, PeerHostKey{
Hostname: fmt.Sprintf("peer%d", i), Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1), IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i), FQDN: fmt.Sprintf("peer%d.nb.internal", i),
HostKey: pubKey.PublicKey(),
}) })
} }
// Test that SSH config generation is skipped when too many peers // Test that SSH config generation is skipped when too many peers
err = manager.SetupSSHClientConfig(peerKeys) err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err) require.NoError(t, err)
// Config should not be created due to peer limit // Config should not be created due to peer limit
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile) configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
_, err = os.Stat(configPath) _, err = os.Stat(configPath)
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers") assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
// Test that known_hosts update is also skipped
err = manager.UpdatePeerHostKeys(peerKeys)
require.NoError(t, err)
// Known hosts should not be created due to peer limit
knownHostsPath := filepath.Join(manager.knownHostsDir, manager.knownHostsFile)
_, err = os.Stat(knownHostsPath)
assert.True(t, os.IsNotExist(err), "Known hosts should not be created with too many peers")
} }
func TestManager_ForcedSSHConfig(t *testing.T) { func TestManager_ForcedSSHConfig(t *testing.T) {
@@ -303,31 +127,22 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
// Override manager paths to use temp directory // Override manager paths to use temp directory
manager := &Manager{ manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"), sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf", sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
} }
// Generate many peer keys (more than limit) // Generate many peers (more than limit)
var peerKeys []PeerHostKey var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ { for i := 0; i < MaxPeersForSSHConfig+10; i++ {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) peers = append(peers, PeerSSHInfo{
require.NoError(t, err)
pubKey, err := ssh.ParsePrivateKey(hostKey)
require.NoError(t, err)
peerKeys = append(peerKeys, PeerHostKey{
Hostname: fmt.Sprintf("peer%d", i), Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1), IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i), FQDN: fmt.Sprintf("peer%d.nb.internal", i),
HostKey: pubKey.PublicKey(),
}) })
} }
// Test that SSH config generation is forced despite many peers // Test that SSH config generation is forced despite many peers
err = manager.SetupSSHClientConfig(peerKeys) err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err) require.NoError(t, err)
// Config should be created despite peer limit due to force flag // Config should be created despite peer limit due to force flag

View File

@@ -0,0 +1,22 @@
package config
// ShutdownState represents SSH configuration state that needs to be cleaned up.
type ShutdownState struct {
SSHConfigDir string
SSHConfigFile string
}
// Name returns the state name for the state manager.
func (s *ShutdownState) Name() string {
return "ssh_config_state"
}
// Cleanup removes SSH client configuration files.
func (s *ShutdownState) Cleanup() error {
manager := &Manager{
sshConfigDir: s.SSHConfigDir,
sshConfigFile: s.SSHConfigFile,
}
return manager.RemoveSSHClientConfig()
}

View File

@@ -0,0 +1,99 @@
package detection
import (
"bufio"
"context"
"net"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
// ServerIdentifier is the base response for NetBird SSH servers
ServerIdentifier = "NetBird-SSH-Server"
// ProxyIdentifier is the base response for NetBird SSH proxy
ProxyIdentifier = "NetBird-SSH-Proxy"
// JWTRequiredMarker is appended to responses when JWT is required
JWTRequiredMarker = "NetBird-JWT-Required"
// Timeout is the timeout for SSH server detection
Timeout = 5 * time.Second
)
type ServerType string
const (
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
ServerTypeRegular ServerType = "regular"
)
// Dialer provides network connection capabilities
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// RequiresJWT checks if the server type requires JWT authentication
func (s ServerType) RequiresJWT() bool {
return s == ServerTypeNetBirdJWT
}
// ExitCode returns the exit code for the detect command
func (s ServerType) ExitCode() int {
switch s {
case ServerTypeNetBirdJWT:
return 0
case ServerTypeNetBirdNoJWT:
return 1
case ServerTypeRegular:
return 2
default:
return 2
}
}
// DetectSSHServerType detects SSH server type using the provided dialer
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
if err != nil {
log.Debugf("SSH connection failed for detection: %v", err)
return ServerTypeRegular, nil
}
defer conn.Close()
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
log.Debugf("set read deadline: %v", err)
return ServerTypeRegular, nil
}
reader := bufio.NewReader(conn)
serverBanner, err := reader.ReadString('\n')
if err != nil {
log.Debugf("read SSH banner: %v", err)
return ServerTypeRegular, nil
}
serverBanner = strings.TrimSpace(serverBanner)
log.Debugf("SSH server banner: %s", serverBanner)
if !strings.HasPrefix(serverBanner, "SSH-") {
log.Debugf("Invalid SSH banner")
return ServerTypeRegular, nil
}
if !strings.Contains(serverBanner, ServerIdentifier) {
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
return ServerTypeRegular, nil
}
if strings.Contains(serverBanner, JWTRequiredMarker) {
return ServerTypeNetBirdJWT, nil
}
return ServerTypeNetBirdNoJWT, nil
}

359
client/ssh/proxy/proxy.go Normal file
View File

@@ -0,0 +1,359 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/version"
)
const (
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
sshConnectionTimeout = 120 * time.Second
// sshHandshakeTimeout is the timeout for SSH handshake completion
sshHandshakeTimeout = 30 * time.Second
jwtAuthErrorMsg = "JWT authentication: %w"
)
type SSHProxy struct {
daemonAddr string
targetHost string
targetPort int
stderr io.Writer
daemonClient proto.DaemonServiceClient
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, fmt.Errorf("connect to daemon: %w", err)
}
return &SSHProxy{
daemonAddr: daemonAddr,
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}, nil
}
func (p *SSHProxy) Connect(ctx context.Context) error {
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": func(s ssh.Session) {
p.sftpSubsystemHandler(s, jwtToken)
},
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": p.tcpipForwardHandler,
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
},
Version: serverVersion,
}
hostKey, err := generateHostKey()
if err != nil {
return fmt.Errorf("generate host key: %w", err)
}
sshServer.HostSigners = []ssh.Signer{hostKey}
conn := &stdioConn{
stdin: os.Stdin,
stdout: os.Stdout,
}
sshServer.HandleConn(conn)
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
_ = serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil)
go func() {
for win := range winCh {
_ = serverSession.WindowChange(win.Height, win.Width)
}
}()
}
if len(session.Command()) > 0 {
_ = serverSession.Run(strings.Join(session.Command(), " "))
return
}
if err = serverSession.Shell(); err == nil {
_ = serverSession.Wait()
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
return nil, fmt.Errorf("generate ED25519 key: %w", err)
}
signer, err := cryptossh.ParsePrivateKey(keyPEM)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
return signer, nil
}
type stdioConn struct {
stdin io.Reader
stdout io.Writer
closed bool
mu sync.Mutex
}
func (c *stdioConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.EOF
}
c.mu.Unlock()
return c.stdin.Read(b)
}
func (c *stdioConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.ErrClosedPipe
}
c.mu.Unlock()
return c.stdout.Write(b)
}
func (c *stdioConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return nil
}
func (c *stdioConn) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
ctx, cancel := context.WithCancel(s.Context())
defer cancel()
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := sshClient.Close(); err != nil {
log.Debugf("close SSH client: %v", err)
}
}()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := serverSession.Close(); err != nil {
log.Debugf("close server session: %v", err)
}
}()
stdin, stdout, err := p.setupSFTPPipes(serverSession)
if err != nil {
log.Debugf("setup SFTP pipes: %v", err)
_ = s.Exit(1)
return
}
if err := serverSession.RequestSubsystem("sftp"); err != nil {
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
_ = s.Exit(1)
return
}
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
}
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
stdin, err := serverSession.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
}
stdout, err := serverSession.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
}
return stdin, stdout, nil
}
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
copyErrCh := make(chan error, 2)
go func() {
_, err := io.Copy(stdin, s)
if err != nil {
log.Debugf("SFTP client to server copy: %v", err)
}
if err := stdin.Close(); err != nil {
log.Debugf("close stdin: %v", err)
}
copyErrCh <- err
}()
go func() {
_, err := io.Copy(s, stdout)
if err != nil {
log.Debugf("SFTP server to client copy: %v", err)
}
copyErrCh <- err
}()
go func() {
<-ctx.Done()
if err := serverSession.Close(); err != nil {
log.Debugf("force close server session on context cancellation: %v", err)
}
}()
for i := 0; i < 2; i++ {
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
log.Debugf("SFTP copy error: %v", err)
}
}
if err := serverSession.Wait(); err != nil {
log.Debugf("SFTP session ended: %v", err)
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
}
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return true, nil
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
config := &cryptossh.ClientConfig{
User: user,
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
Timeout: sshHandshakeTimeout,
HostKeyCallback: p.verifyHostKey,
}
dialer := &net.Dialer{
Timeout: sshConnectionTimeout,
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("connect to server: %w", err)
}
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("SSH handshake: %w", err)
}
return cryptossh.NewClient(clientConn, chans, reqs), nil
}
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}

View File

@@ -0,0 +1,361 @@
package proxy
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
)
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "ssh" {
if os.Args[2] == "exec" {
if len(os.Args) > 3 {
cmd := os.Args[3]
if cmd == "echo" && len(os.Args) > 4 {
fmt.Fprintln(os.Stdout, os.Args[4])
os.Exit(0)
}
}
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
os.Exit(1)
}
}
code := m.Run()
testutil.CleanupTestUsers()
os.Exit(code)
}
func TestSSHProxy_verifyHostKey(t *testing.T) {
t.Run("calls daemon to verify host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testPubKey, err := nbssh.GeneratePublicKey(testKey)
require.NoError(t, err)
mockDaemon.setHostKey("test-host", testPubKey)
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
assert.NoError(t, err)
})
t.Run("rejects unknown host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
require.NoError(t, err)
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
assert.Error(t, err)
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
jwtToken string
}
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
key, found := m.hostKeys[req.PeerAddress]
return &proto.GetPeerSSHHostKeyResponse{
Found: found,
SshHostKey: key,
}, nil
}
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
return &proto.RequestJWTAuthResponse{
CachedToken: m.jwtToken,
}, nil
}
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
return &proto.WaitJWTTokenResponse{
Token: m.jwtToken,
}, nil
}
type mockDaemon struct {
addr string
server *grpc.Server
impl *mockDaemonServer
}
func startMockDaemon(t *testing.T) *mockDaemon {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
impl := &mockDaemonServer{
hostKeys: make(map[string][]byte),
jwtToken: "test-jwt-token",
}
grpcServer := grpc.NewServer()
proto.RegisterDaemonServiceServer(grpcServer, impl)
go func() {
_ = grpcServer.Serve(listener)
}()
return &mockDaemon{
addr: listener.Addr().String(),
server: grpcServer,
impl: impl,
}
}
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
}
}
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
t.Helper()
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -9,7 +9,6 @@ import (
"net" "net"
"os" "os"
"os/exec" "os/exec"
"os/user"
"runtime" "runtime"
"strings" "strings"
"testing" "testing"
@@ -21,15 +20,24 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
) )
// TestMain handles package-level setup and cleanup // TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) { func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests // Run tests
code := m.Run() code := m.Run()
// Cleanup any created test users // Cleanup any created test users
cleanupTestUsers() testutil.CleanupTestUsers()
os.Exit(code) os.Exit(code)
} }
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Generate OpenSSH-compatible keys for client // Generate OpenSSH-compatible keys for client
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t) clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err) require.NoError(t, err)
server := New(hostKey) serverConfig := &Config{
server.SetAllowRootLogin(true) // Allow root login for testing HostKeyPEM: hostKey,
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH)) JWT: nil,
require.NoError(t, err) }
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server) serverAddr := StartTestServer(t, server)
defer func() { defer func() {
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
// Get appropriate user for SSH connection (handle system accounts) // Get appropriate user for SSH connection (handle system accounts)
username := getTestUsername(t) username := testutil.GetTestUsername(t)
t.Run("basic command execution", func(t *testing.T) { t.Run("basic command execution", func(t *testing.T) {
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username) testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
// testSSHInteractiveCommand tests interactive shell session. // testSSHInteractiveCommand tests interactive shell session.
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) { func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection // Get appropriate user for SSH connection
username := getTestUsername(t) username := testutil.GetTestUsername(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second) ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel() defer cancel()
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// testSSHPortForwarding tests port forwarding compatibility. // testSSHPortForwarding tests port forwarding compatibility.
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) { func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection // Get appropriate user for SSH connection
username := getTestUsername(t) username := testutil.GetTestUsername(t)
testServer, err := net.Listen("tcp", "127.0.0.1:0") testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err) require.NoError(t, err)
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
t.Skip("Skipping SSH feature compatibility tests in short mode") t.Skip("Skipping SSH feature compatibility tests in short mode")
} }
if runtime.GOOS == "windows" && isCI() { if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues") t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
} }
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey) serverConfig := &Config{
server.SetAllowRootLogin(true) // Allow root login for testing HostKeyPEM: hostKey,
err = server.AddAuthorizedKey("test-peer", string(clientPubKey)) JWT: nil,
require.NoError(t, err) }
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server) serverAddr := StartTestServer(t, server)
defer func() { defer func() {
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
// testCommandWithFlags tests that commands with flags work properly // testCommandWithFlags tests that commands with flags work properly
func testCommandWithFlags(t *testing.T, host, port, keyFile string) { func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection // Get appropriate user for SSH connection
username := getTestUsername(t) username := testutil.GetTestUsername(t)
// Test ls with flags // Test ls with flags
cmd := exec.Command("ssh", cmd := exec.Command("ssh",
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// testEnvironmentVariables tests that environment is properly set up // testEnvironmentVariables tests that environment is properly set up
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) { func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection // Get appropriate user for SSH connection
username := getTestUsername(t) username := testutil.GetTestUsername(t)
cmd := exec.Command("ssh", cmd := exec.Command("ssh",
"-i", keyFile, "-i", keyFile,
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// testExitCodes tests that exit codes are properly handled // testExitCodes tests that exit codes are properly handled
func testExitCodes(t *testing.T, host, port, keyFile string) { func testExitCodes(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection // Get appropriate user for SSH connection
username := getTestUsername(t) username := testutil.GetTestUsername(t)
// Test successful command (exit code 0) // Test successful command (exit code 0)
cmd := exec.Command("ssh", cmd := exec.Command("ssh",
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
} }
// Get appropriate user for SSH connection // Get appropriate user for SSH connection
username := getTestUsername(t) username := testutil.GetTestUsername(t)
// Set up SSH server with specific security settings // Set up SSH server with specific security settings
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey) serverConfig := &Config{
server.SetAllowRootLogin(true) // Allow root login for testing HostKeyPEM: hostKey,
err = server.AddAuthorizedKey("test-peer", string(clientPubKey)) JWT: nil,
require.NoError(t, err) }
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server) serverAddr := StartTestServer(t, server)
defer func() { defer func() {
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
} }
// Get appropriate user for SSH connection // Get appropriate user for SSH connection
username := getTestUsername(t) username := testutil.GetTestUsername(t)
// Set up SSH server // Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey) serverConfig := &Config{
server.SetAllowRootLogin(true) // Allow root login for testing HostKeyPEM: hostKey,
err = server.AddAuthorizedKey("test-peer", string(clientPubKey)) JWT: nil,
require.NoError(t, err) }
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server) serverAddr := StartTestServer(t, server)
defer func() { defer func() {
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
t.Logf("Platform command output: %s", outputStr) t.Logf("Platform command output: %s", outputStr)
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output") assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
} }
// getTestUsername returns an appropriate username for testing
func getTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Check if this is a system account that can't authenticate
if isSystemAccount(currentUser.Username) {
// In CI environments, create a test user; otherwise try Administrator
if isCI() {
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
// Try Administrator first for local development
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// isCI checks if we're running in a CI environment
func isCI() bool {
// Check standard CI environment variables
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// isSystemAccount checks if the user is a system account that can't authenticate
func isSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
var compatTestCreatedUsers = make(map[string]bool)
var compatTestUsersToCleanup []string
// registerTestUserCleanup registers a test user for cleanup
func registerTestUserCleanup(username string) {
if !compatTestCreatedUsers[username] {
compatTestCreatedUsers[username] = true
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
}
}
// cleanupTestUsers removes all created test users
func cleanupTestUsers() {
for _, username := range compatTestUsersToCleanup {
removeWindowsTestUser(username)
}
compatTestUsersToCleanup = nil
compatTestCreatedUsers = make(map[string]bool)
}
// getOrCreateTestUser creates a test user on Windows if needed
func getOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
// Check if user already exists
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
// Try to create the user using PowerShell
if createWindowsTestUser(t, testUsername) {
// Register cleanup for the test user
registerTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// removeWindowsTestUser removes a local user on Windows using PowerShell
func removeWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
// PowerShell command to remove a local user
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// createWindowsTestUser creates a local user on Windows using PowerShell
func createWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
// PowerShell command to create a local user
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}

View File

@@ -0,0 +1,610 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
cryptossh "golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
)
func TestJWTEnforcement(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT enforcement tests in short mode")
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
if err != nil {
t.Logf("Detection failed: %v", err)
}
t.Logf("Detected server type: %s", serverType)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
})
t.Run("allows_when_disabled", func(t *testing.T) {
serverConfigNoJWT := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
serverNoJWT := New(serverConfigNoJWT)
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
portNoJWT, err := strconv.Atoi(portStrNoJWT)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
assert.False(t, serverType.RequiresJWT())
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
require.NoError(t, err)
defer client.Close()
})
}
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64RawURLEncode(n),
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func base64RawURLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
// generateValidJWT creates a valid JWT token for testing
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
t.Helper()
addr := net.JoinHostPort(host, strconv.Itoa(port))
ctx := context.Background()
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
InsecureSkipVerify: true,
})
}
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
func TestJWTDetection(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT detection test in short mode")
}
jwksServer, _, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
assert.True(t, serverType.RequiresJWT())
}
func TestJWTFailClose(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT fail-close tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
tokenClaims jwt.MapClaims
}{
{
name: "blocks_token_missing_iat",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
},
},
{
name: "blocks_token_missing_sub",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_iss",
tokenClaims: jwt.MapClaims{
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_aud",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_issuer",
tokenClaims: jwt.MapClaims{
"iss": "wrong-issuer",
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_audience",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": "wrong-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_expired_token",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(tokenString),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
assert.Error(t, err, "Authentication should fail (fail-close)")
})
}
}
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
func TestJWTAuthentication(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT authentication tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
token string
wantAuthOK bool
setupServer func(*Server)
testOperation func(*testing.T, *cryptossh.Client, string) error
wantOpSuccess bool
}{
{
name: "allows_shell_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
return session.Shell()
},
wantOpSuccess: true,
},
{
name: "rejects_invalid_token",
token: "invalid",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_shell_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_command_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("ls")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "allows_sftp_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
return session.RequestSubsystem("sftp")
},
wantOpSuccess: true,
},
{
name: "blocks_sftp_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
err = session.RequestSubsystem("sftp")
if err == nil {
err = session.Wait()
}
return err
},
wantOpSuccess: false,
},
{
name: "allows_port_forward_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowRemotePortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: true,
},
{
name: "blocks_port_forward_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowLocalPortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
if tc.setupServer != nil {
tc.setupServer(server)
}
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
if tc.token == "valid" {
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
} else if tc.token == "invalid" {
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),
}
}
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: authMethods,
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed")
} else if err != nil {
t.Logf("Connection failed as expected: %v", err)
return
}
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
err = tc.testOperation(t, conn, serverAddr)
if tc.wantOpSuccess {
require.NoError(t, err, "Operation should succeed")
} else {
assert.Error(t, err, "Operation should fail")
}
})
}
}

View File

@@ -2,18 +2,27 @@ package server
import ( import (
"context" "context"
"encoding/base64"
"encoding/json"
"errors" "errors"
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"strings"
"sync" "sync"
"time"
"github.com/gliderlabs/ssh" "github.com/gliderlabs/ssh"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh" cryptossh "golang.org/x/crypto/ssh"
"golang.zx2c4.com/wireguard/tun/netstack" "golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/version"
) )
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server // DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
@@ -27,6 +36,9 @@ const (
errExitSession = "exit session error: %v" errExitSession = "exit session error: %v"
msgPrivilegedUserDisabled = "privileged user login is disabled" msgPrivilegedUserDisabled = "privileged user login is disabled"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
) )
var ( var (
@@ -69,7 +81,6 @@ func (e *UserNotFoundError) Unwrap() error {
} }
// safeLogCommand returns a safe representation of the command for logging // safeLogCommand returns a safe representation of the command for logging
// Only logs the first argument to avoid leaking sensitive information
func safeLogCommand(cmd []string) string { func safeLogCommand(cmd []string) string {
if len(cmd) == 0 { if len(cmd) == 0 {
return "<empty>" return "<empty>"
@@ -80,17 +91,14 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1) return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
} }
// sshConnectionState tracks the state of an SSH connection
type sshConnectionState struct { type sshConnectionState struct {
hasActivePortForward bool hasActivePortForward bool
username string username string
remoteAddr string remoteAddr string
} }
// Server is the SSH server implementation
type Server struct { type Server struct {
sshServer *ssh.Server sshServer *ssh.Server
authorizedKeys map[string]ssh.PublicKey
mu sync.RWMutex mu sync.RWMutex
hostKeyPEM []byte hostKeyPEM []byte
sessions map[SessionKey]ssh.Session sessions map[SessionKey]ssh.Session
@@ -100,30 +108,53 @@ type Server struct {
allowRemotePortForwarding bool allowRemotePortForwarding bool
allowRootLogin bool allowRootLogin bool
allowSFTP bool allowSFTP bool
jwtEnabled bool
netstackNet *netstack.Net netstackNet *netstack.Net
wgAddress wgaddr.Address wgAddress wgaddr.Address
ifIdx int
remoteForwardListeners map[ForwardKey]net.Listener remoteForwardListeners map[ForwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState sshConnections map[*cryptossh.ServerConn]*sshConnectionState
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
} }
// New creates an SSH server instance with the provided host key type JWTConfig struct {
func New(hostKeyPEM []byte) *Server { Issuer string
return &Server{ Audience string
KeysLocation string
MaxTokenAge int64
}
// Config contains all SSH server configuration options
type Config struct {
// JWT authentication configuration. If nil, JWT authentication is disabled
JWT *JWTConfig
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{}, mu: sync.RWMutex{},
hostKeyPEM: hostKeyPEM, hostKeyPEM: config.HostKeyPEM,
authorizedKeys: make(map[string]ssh.PublicKey),
sessions: make(map[SessionKey]ssh.Session), sessions: make(map[SessionKey]ssh.Session),
remoteForwardListeners: make(map[ForwardKey]net.Listener), remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState), sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
} }
return s
} }
// Start runs the SSH server, automatically detecting netstack vs standard networking // Start runs the SSH server
// Does all setup synchronously, then starts serving in a goroutine and returns immediately
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error { func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -139,7 +170,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
sshServer, err := s.createSSHServer(ln.Addr()) sshServer, err := s.createSSHServer(ln.Addr())
if err != nil { if err != nil {
s.cleanupOnError(ln) s.closeListener(ln)
return fmt.Errorf("create SSH server: %w", err) return fmt.Errorf("create SSH server: %w", err)
} }
@@ -154,7 +185,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return nil return nil
} }
// createListener creates a network listener based on netstack vs standard networking
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) { func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil { if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr) ln, err := s.netstackNet.ListenTCPAddrPort(addr)
@@ -173,22 +203,15 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
return ln, addr.String(), nil return ln, addr.String(), nil
} }
// closeListener safely closes a listener
func (s *Server) closeListener(ln net.Listener) { func (s *Server) closeListener(ln net.Listener) {
if ln == nil {
return
}
if err := ln.Close(); err != nil { if err := ln.Close(); err != nil {
log.Debugf("listener close error: %v", err) log.Debugf("listener close error: %v", err)
} }
} }
// cleanupOnError cleans up resources when SSH server creation fails
func (s *Server) cleanupOnError(ln net.Listener) {
if s.ifIdx == 0 || ln == nil {
return
}
s.closeListener(ln)
}
// Stop closes the SSH server // Stop closes the SSH server
func (s *Server) Stop() error { func (s *Server) Stop() error {
s.mu.Lock() s.mu.Lock()
@@ -207,28 +230,6 @@ func (s *Server) Stop() error {
return nil return nil
} }
// RemoveAuthorizedKey removes the SSH key for a peer
func (s *Server) RemoveAuthorizedKey(peer string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.authorizedKeys, peer)
}
// AddAuthorizedKey adds an SSH key for a peer
func (s *Server) AddAuthorizedKey(peer, newKey string) error {
s.mu.Lock()
defer s.mu.Unlock()
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
if err != nil {
return fmt.Errorf("parse key: %w", err)
}
s.authorizedKeys[peer] = parsedKey
return nil
}
// SetNetstackNet sets the netstack network for userspace networking // SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) { func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock() s.mu.Lock()
@@ -243,34 +244,195 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr s.wgAddress = addr
} }
// SetSocketFilter configures eBPF socket filtering for the SSH server // ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) SetSocketFilter(ifIdx int) { func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
s.mu.RUnlock()
return nil
}
config := s.jwtConfig
s.mu.RUnlock()
if config == nil {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.KeysLocation,
true,
)
extractor := jwt.NewClaimsExtractor(
jwt.WithAudience(config.Audience),
)
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.ifIdx = ifIdx
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
s.jwtValidator = validator
s.jwtExtractor = extractor
log.Infof("JWT validator initialized successfully")
return nil
} }
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool { func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
s.mu.RLock() s.mu.RLock()
defer s.mu.RUnlock() jwtValidator := s.jwtValidator
jwtConfig := s.jwtConfig
s.mu.RUnlock()
for _, allowed := range s.authorizedKeys { if jwtValidator == nil {
if ssh.KeysEqual(allowed, key) { return nil, fmt.Errorf("JWT validator not initialized")
if ctx != nil { }
log.Debugf("SSH key authentication successful for user %s from %s", ctx.User(), ctx.RemoteAddr())
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
} }
return true
} }
return nil, fmt.Errorf("validate token: %w", err)
} }
if ctx != nil { if err := s.checkTokenAge(token, jwtConfig); err != nil {
log.Warnf("SSH key authentication failed for user %s from %s: key not authorized (type: %s, fingerprint: %s)", return nil, err
ctx.User(), ctx.RemoteAddr(), key.Type(), cryptossh.FingerprintSHA256(key))
} }
return false
return token, nil
}
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
return nil
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
s.mu.RLock()
jwtExtractor := s.jwtExtractor
s.mu.RUnlock()
if jwtExtractor == nil {
userID := extractUserID(token)
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
}
userAuth, err := jwtExtractor.ToUserAuth(token)
if err != nil {
userID := extractUserID(token)
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
}
if !s.hasSSHAccess(&userAuth) {
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
}
return &userAuth, nil
}
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
return userAuth.UserId != ""
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid token format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("decode payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("parse claims: %w", err)
}
return claims, nil
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
} }
// markConnectionActivePortForward marks an SSH connection as having an active port forward
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) { func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -286,14 +448,12 @@ func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn,
} }
} }
// connectionCloseHandler cleans up connection state when SSH connections fail/close
func (s *Server) connectionCloseHandler(conn net.Conn, err error) { func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
// We can't extract the SSH connection from net.Conn directly // We can't extract the SSH connection from net.Conn directly
// Connection cleanup will happen during session cleanup or via timeout // Connection cleanup will happen during session cleanup or via timeout
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err) log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
} }
// findSessionKeyByContext finds the session key by matching SSH connection context
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey { func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
if ctx == nil { if ctx == nil {
return "unknown" return "unknown"
@@ -319,14 +479,13 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
// Return a temporary key that we'll fix up later // Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil { if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String())) tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("using temporary session key for port forward tracking: %s", tempKey) log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey return tempKey
} }
return "unknown" return "unknown"
} }
// connectionValidator validates incoming connections based on source IP
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn { func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock() s.mu.RLock()
netbirdNetwork := s.wgAddress.Network netbirdNetwork := s.wgAddress.Network
@@ -340,8 +499,8 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
remoteAddr := conn.RemoteAddr() remoteAddr := conn.RemoteAddr()
tcpAddr, ok := remoteAddr.(*net.TCPAddr) tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok { if !ok {
log.Debugf("SSH connection from non-TCP address %s allowed", remoteAddr) log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
return conn return nil
} }
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP) remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
@@ -357,15 +516,14 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
} }
if !netbirdNetwork.Contains(remoteIP) { if !netbirdNetwork.Contains(remoteIP) {
log.Warnf("SSH connection rejected from non-NetBird IP %s (allowed range: %s)", remoteIP, netbirdNetwork) log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil return nil
} }
log.Debugf("SSH connection from %s allowed", remoteIP) log.Infof("SSH connection from NetBird peer %s allowed", remoteIP)
return conn return conn
} }
// isShutdownError checks if the error is expected during normal shutdown
func isShutdownError(err error) bool { func isShutdownError(err error) bool {
if errors.Is(err, net.ErrClosed) { if errors.Is(err, net.ErrClosed) {
return true return true
@@ -379,12 +537,16 @@ func isShutdownError(err error) bool {
return false return false
} }
// createSSHServer creates and configures the SSH server
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) { func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil { if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err) log.Warnf("failed to enable user switching: %v", err)
} }
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
if s.jwtEnabled {
serverVersion += " " + detection.JWTRequiredMarker
}
server := &ssh.Server{ server := &ssh.Server{
Addr: addr.String(), Addr: addr.String(),
Handler: s.sessionHandler, Handler: s.sessionHandler,
@@ -402,6 +564,11 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
}, },
ConnCallback: s.connectionValidator, ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler, ConnectionFailedCallback: s.connectionCloseHandler,
Version: serverVersion,
}
if s.jwtEnabled {
server.PasswordHandler = s.passwordHandler
} }
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM) hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
@@ -413,14 +580,12 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil return server, nil
} }
// storeRemoteForwardListener stores a remote forward listener for cleanup
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) { func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln s.remoteForwardListeners[key] = ln
} }
// removeRemoteForwardListener removes and closes a remote forward listener
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool { func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
s.mu.Lock() s.mu.Lock()
defer s.mu.Unlock() defer s.mu.Unlock()
@@ -438,7 +603,6 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
return true return true
} }
// directTCPIPHandler handles direct-tcpip channel requests for local port forwarding with privilege validation
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) { func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
var payload struct { var payload struct {
Host string Host string

View File

@@ -22,12 +22,6 @@ func TestServer_RootLoginRestriction(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
tests := []struct { tests := []struct {
name string name string
allowRoot bool allowRoot bool
@@ -117,10 +111,12 @@ func TestServer_RootLoginRestriction(t *testing.T) {
defer cleanup() defer cleanup()
// Create server with specific configuration // Create server with specific configuration
server := New(hostKey) serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(tt.allowRoot) server.SetAllowRootLogin(tt.allowRoot)
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Test the userNameLookup method directly // Test the userNameLookup method directly
user, err := server.userNameLookup(tt.username) user, err := server.userNameLookup(tt.username)
@@ -196,7 +192,11 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
for _, tt := range tests { for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) { t.Run(tt.name, func(t *testing.T) {
// Create server with specific configuration // Create server with specific configuration
server := New(hostKey) serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding) server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding) server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
@@ -234,17 +234,13 @@ func TestServer_PortConflictHandling(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519) hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server // Create server
server := New(hostKey) serverConfig := &Config{
server.SetAllowRootLogin(true) // Allow root login for testing HostKeyPEM: hostKey,
err = server.AddAuthorizedKey("test-peer", string(clientPubKey)) JWT: nil,
require.NoError(t, err) }
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server) serverAddr := StartTestServer(t, server)
defer func() { defer func() {
@@ -263,7 +259,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second) ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1() defer cancel1()
client1, err := sshclient.DialInsecure(ctx1, serverAddr, currentUser.Username) client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err := client1.Close() err := client1.Close()
@@ -274,7 +272,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second) ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2() defer cancel2()
client2, err := sshclient.DialInsecure(ctx2, serverAddr, currentUser.Username) client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err) require.NoError(t, err)
defer func() { defer func() {
err := client2.Close() err := client2.Close()

View File

@@ -7,11 +7,9 @@ import (
"net/netip" "net/netip"
"os/user" "os/user"
"runtime" "runtime"
"strings"
"testing" "testing"
"time" "time"
"github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh" cryptossh "golang.org/x/crypto/ssh"
@@ -19,82 +17,15 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh" nbssh "github.com/netbirdio/netbird/client/ssh"
) )
func TestServer_AddAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
keys := map[string][]byte{}
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys[peer] = remotePubKey
}
for peer, remotePubKey := range keys {
k, ok := server.authorizedKeys[peer]
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
}
}
func TestServer_RemoveAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
require.NoError(t, err)
server.RemoveAuthorizedKey("remotePeer")
_, ok := server.authorizedKeys["remotePeer"]
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
}
func TestServer_PubKeyHandler(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
var keys []ssh.PublicKey
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys = append(keys, remoteParsedPubKey)
}
for _, key := range keys {
accepted := server.publicKeyHandler(nil, key)
assert.True(t, accepted, "SSH key should be accepted")
}
}
func TestServer_StartStop(t *testing.T) { func TestServer_StartStop(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519) key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
server := New(key) serverConfig := &Config{
HostKeyPEM: key,
JWT: nil,
}
server := New(serverConfig)
err = server.Stop() err = server.Stop()
assert.NoError(t, err) assert.NoError(t, err)
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
// Generate client key pair // Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with random port // Create server with random port
server := New(hostKey) serverConfig := &Config{
HostKeyPEM: hostKey,
// Add client's public key as authorized JWT: nil,
err = server.AddAuthorizedKey("test-peer", string(clientPubKey)) }
require.NoError(t, err) server := New(serverConfig)
// Start server in background // Start server in background
serverAddr := "127.0.0.1:0" serverAddr := "127.0.0.1:0"
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
// Generate client key pair // Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server // Create server
server := New(hostKey) serverConfig := &Config{
err = server.AddAuthorizedKey("test-peer", string(clientPubKey)) HostKeyPEM: hostKey,
require.NoError(t, err) JWT: nil,
}
server := New(serverConfig)
// Start server // Start server
serverAddr := "127.0.0.1:0" serverAddr := "127.0.0.1:0"
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
// Generate authorized key // Create server
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) serverConfig := &Config{
require.NoError(t, err) HostKeyPEM: hostKey,
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey) JWT: nil,
require.NoError(t, err) }
server := New(serverConfig)
// Generate unauthorized key (different from authorized)
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server with only one authorized key
server := New(hostKey)
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
require.NoError(t, err)
// Start server // Start server
serverAddr := "127.0.0.1:0" serverAddr := "127.0.0.1:0"
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
}() }()
// Parse unauthorized private key // Generate a client private key for SSH protocol (server doesn't check it)
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey) clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err) require.NoError(t, err)
// Parse server host key // Parse server host key
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
currentUser, err := user.Current() currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test") require.NoError(t, err, "Should be able to get current user for test")
// Try to connect with unauthorized key // Try to connect with client key
config := &cryptossh.ClientConfig{ config := &cryptossh.ClientConfig{
User: currentUser.Username, User: currentUser.Username,
Auth: []cryptossh.AuthMethod{ Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(unauthorizedSigner), cryptossh.PublicKeys(clientSigner),
}, },
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey), HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second, Timeout: 3 * time.Second,
} }
// This should succeed in no-auth mode // This should succeed in no-auth mode (server doesn't verify keys)
conn, err := cryptossh.Dial("tcp", serverAddr, config) conn, err := cryptossh.Dial("tcp", serverAddr, config)
assert.NoError(t, err, "Connection should succeed in no-auth mode") assert.NoError(t, err, "Connection should succeed in no-auth mode")
if conn != nil { if conn != nil {
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
server := New(hostKey) serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
serverAddr := "127.0.0.1:0" serverAddr := "127.0.0.1:0"
// Test multiple start/stop cycles // Test multiple start/stop cycles
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519) hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
server1 := New(hostKey) serverConfig1 := &Config{
server2 := New(hostKey) HostKeyPEM: hostKey,
JWT: nil,
}
server1 := New(serverConfig1)
serverConfig2 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server2 := New(serverConfig2)
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security") assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security") assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")

View File

@@ -35,17 +35,15 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
// Generate client key pair // Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP enabled // Create server with SFTP enabled
server := New(hostKey) serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(true) server.SetAllowSFTP(true)
server.SetAllowRootLogin(true) // Allow root login for testing server.SetAllowRootLogin(true)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Start server // Start server
serverAddr := "127.0.0.1:0" serverAddr := "127.0.0.1:0"
@@ -144,17 +142,15 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
// Generate client key pair // Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519) clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err) require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP disabled // Create server with SFTP disabled
server := New(hostKey) serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(false) server.SetAllowSFTP(false)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Start server // Start server
serverAddr := "127.0.0.1:0" serverAddr := "127.0.0.1:0"
started := make(chan string, 1) started := make(chan string, 1)

View File

@@ -14,7 +14,6 @@ func StartTestServer(t *testing.T, server *Server) string {
errChan := make(chan error, 1) errChan := make(chan error, 1)
go func() { go func() {
// Get a free port
ln, err := net.Listen("tcp", "127.0.0.1:0") ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil { if err != nil {
errChan <- err errChan <- err
@@ -26,9 +25,12 @@ func StartTestServer(t *testing.T, server *Server) string {
return return
} }
started <- actualAddr
addrPort := netip.MustParseAddrPort(actualAddr) addrPort := netip.MustParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort) if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}() }()
select { select {

View File

@@ -0,0 +1,172 @@
package testutil
import (
"fmt"
"log"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
var testCreatedUsers = make(map[string]bool)
var testUsersToCleanup []string
// GetTestUsername returns an appropriate username for testing
func GetTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
if IsSystemAccount(currentUser.Username) {
if IsCI() {
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// IsCI checks if we're running in a CI environment
func IsCI() bool {
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// IsSystemAccount checks if the user is a system account that can't authenticate
func IsSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
// RegisterTestUserCleanup registers a test user for cleanup
func RegisterTestUserCleanup(username string) {
if !testCreatedUsers[username] {
testCreatedUsers[username] = true
testUsersToCleanup = append(testUsersToCleanup, username)
}
}
// CleanupTestUsers removes all created test users
func CleanupTestUsers() {
for _, username := range testUsersToCleanup {
RemoveWindowsTestUser(username)
}
testUsersToCleanup = nil
testCreatedUsers = make(map[string]bool)
}
// GetOrCreateTestUser creates a test user on Windows if needed
func GetOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
if CreateWindowsTestUser(t, testUsername) {
RegisterTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// RemoveWindowsTestUser removes a local user on Windows using PowerShell
func RemoveWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// CreateWindowsTestUser creates a local user on Windows using PowerShell
func CreateWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}

View File

@@ -77,6 +77,7 @@ type Info struct {
EnableSSHSFTP bool EnableSSHSFTP bool
EnableSSHLocalPortForwarding bool EnableSSHLocalPortForwarding bool
EnableSSHRemotePortForwarding bool EnableSSHRemotePortForwarding bool
DisableSSHAuth bool
} }
func (i *Info) SetFlags( func (i *Info) SetFlags(
@@ -85,6 +86,7 @@ func (i *Info) SetFlags(
disableClientRoutes, disableServerRoutes, disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool, disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool, enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
) { ) {
i.RosenpassEnabled = rosenpassEnabled i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive i.RosenpassPermissive = rosenpassPermissive
@@ -113,6 +115,9 @@ func (i *Info) SetFlags(
if enableSSHRemotePortForwarding != nil { if enableSSHRemotePortForwarding != nil {
i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding
} }
if disableSSHAuth != nil {
i.DisableSSHAuth = *disableSSHAuth
}
} }
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context // extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context

View File

@@ -270,6 +270,7 @@ type serviceClient struct {
sEnableSSHSFTP *widget.Check sEnableSSHSFTP *widget.Check
sEnableSSHLocalPortForward *widget.Check sEnableSSHLocalPortForward *widget.Check
sEnableSSHRemotePortForward *widget.Check sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check
// observable settings over corresponding iMngURL and iPreSharedKey values. // observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string managementURL string
@@ -288,6 +289,7 @@ type serviceClient struct {
enableSSHSFTP bool enableSSHSFTP bool
enableSSHLocalPortForward bool enableSSHLocalPortForward bool
enableSSHRemotePortForward bool enableSSHRemotePortForward bool
disableSSHAuth bool
connected bool connected bool
update *version.Update update *version.Update
@@ -437,6 +439,7 @@ func (s *serviceClient) showSettingsUI() {
s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil) s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil)
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil) s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil) s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.wSettings.SetContent(s.getSettingsForm()) s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 400)) s.wSettings.Resize(fyne.NewSize(600, 400))
@@ -597,6 +600,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
req.EnableSSHLocalPortForward = &s.sEnableSSHLocalPortForward.Checked req.EnableSSHLocalPortForward = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForward = &s.sEnableSSHRemotePortForward.Checked req.EnableSSHRemotePortForward = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
if s.iPreSharedKey.Text != censoredPreSharedKey { if s.iPreSharedKey.Text != censoredPreSharedKey {
req.OptionalPreSharedKey = &s.iPreSharedKey.Text req.OptionalPreSharedKey = &s.iPreSharedKey.Text
@@ -682,6 +686,7 @@ func (s *serviceClient) getSSHForm() *widget.Form {
{Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP}, {Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP},
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward}, {Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward}, {Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
}, },
} }
} }
@@ -690,7 +695,8 @@ func (s *serviceClient) hasSSHChanges() bool {
return s.enableSSHRoot != s.sEnableSSHRoot.Checked || return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked || s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked || s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
s.disableSSHAuth != s.sDisableSSHAuth.Checked
} }
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) { func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
@@ -1233,6 +1239,9 @@ func (s *serviceClient) getSrvConfig() {
if cfg.EnableSSHRemotePortForwarding != nil { if cfg.EnableSSHRemotePortForwarding != nil {
s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding
} }
if cfg.DisableSSHAuth != nil {
s.disableSSHAuth = *cfg.DisableSSHAuth
}
if s.showAdvancedSettings { if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL) s.iMngURL.SetText(s.managementURL)
@@ -1266,6 +1275,9 @@ func (s *serviceClient) getSrvConfig() {
if cfg.EnableSSHRemotePortForwarding != nil { if cfg.EnableSSHRemotePortForwarding != nil {
s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding) s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding)
} }
if cfg.DisableSSHAuth != nil {
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
}
} }
if s.mNotifications == nil { if s.mNotifications == nil {
@@ -1348,6 +1360,9 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
if cfg.EnableSSHRemotePortForwarding { if cfg.EnableSSHRemotePortForwarding {
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
} }
if cfg.DisableSSHAuth {
config.DisableSSHAuth = &cfg.DisableSSHAuth
}
return &config return &config
} }

View File

@@ -245,6 +245,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
} }
h.client.getSrvConfig() h.client.getSrvConfig()
return nil return nil
} }

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
netbird "github.com/netbirdio/netbird/client/embed" netbird "github.com/netbirdio/netbird/client/embed"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh" "github.com/netbirdio/netbird/client/wasm/internal/ssh"
@@ -125,10 +126,15 @@ func createSSHMethod(client *netbird.Client) js.Func {
username = args[2].String() username = args[2].String()
} }
var jwtToken string
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
jwtToken = args[3].String()
}
return createPromise(func(resolve, reject js.Value) { return createPromise(func(resolve, reject js.Value) {
sshClient := ssh.NewClient(client) sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username); err != nil { if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
reject.Invoke(err.Error()) reject.Invoke(err.Error())
return return
} }
@@ -191,12 +197,43 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value {
})) }))
} }
// createDetectSSHServerMethod creates the SSH server detection method
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
}
host := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
serverType, err := detectSSHServerType(ctx, client, host, port)
if err != nil {
reject.Invoke(err.Error())
return
}
resolve.Invoke(js.ValueOf(serverType.RequiresJWT()))
})
})
}
// detectSSHServerType detects SSH server type using NetBird network connection
func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) {
return sshdetection.DetectSSHServerType(ctx, client, host, port)
}
// createClientObject wraps the NetBird client in a JavaScript object // createClientObject wraps the NetBird client in a JavaScript object
func createClientObject(client *netbird.Client) js.Value { func createClientObject(client *netbird.Client) js.Value {
obj := make(map[string]interface{}) obj := make(map[string]interface{})
obj["start"] = createStartMethod(client) obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client) obj["stop"] = createStopMethod(client)
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client) obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client) obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client) obj["createRDPProxy"] = createRDPProxyMethod(client)

View File

@@ -13,6 +13,7 @@ import (
"golang.org/x/crypto/ssh" "golang.org/x/crypto/ssh"
netbird "github.com/netbirdio/netbird/client/embed" netbird "github.com/netbirdio/netbird/client/embed"
nbssh "github.com/netbirdio/netbird/client/ssh"
) )
const ( const (
@@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client {
} }
// Connect establishes an SSH connection through NetBird network // Connect establishes an SSH connection through NetBird network
func (c *Client) Connect(host string, port int, username string) error { func (c *Client) Connect(host string, port int, username, jwtToken string) error {
addr := fmt.Sprintf("%s:%d", host, port) addr := fmt.Sprintf("%s:%d", host, port)
logrus.Infof("SSH: Connecting to %s as %s", addr, username) logrus.Infof("SSH: Connecting to %s as %s", addr, username)
var authMethods []ssh.AuthMethod authMethods, err := c.getAuthMethods(jwtToken)
nbConfig, err := c.nbClient.GetConfig()
if err != nil { if err != nil {
return fmt.Errorf("get NetBird config: %w", err) return err
} }
if nbConfig.SSHKey == "" {
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
}
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
if err != nil {
return fmt.Errorf("parse NetBird SSH private key: %w", err)
}
pubKey := signer.PublicKey()
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
authMethods = append(authMethods, ssh.PublicKeys(signer))
config := &ssh.ClientConfig{ config := &ssh.ClientConfig{
User: username, User: username,
Auth: authMethods, Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient),
Timeout: sshDialTimeout, Timeout: sshDialTimeout,
} }
@@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error {
return nil return nil
} }
// getAuthMethods returns SSH authentication methods, preferring JWT if available
func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) {
if jwtToken != "" {
logrus.Debugf("SSH: Using JWT password authentication")
return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil
}
logrus.Debugf("SSH: No JWT token, using public key authentication")
nbConfig, err := c.nbClient.GetConfig()
if err != nil {
return nil, fmt.Errorf("get NetBird config: %w", err)
}
if nbConfig.SSHKey == "" {
return nil, fmt.Errorf("no NetBird SSH key available")
}
signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey))
if err != nil {
return nil, fmt.Errorf("parse NetBird SSH private key: %w", err)
}
logrus.Debugf("SSH: Added public key auth")
return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
}
// StartSession starts an SSH session with PTY // StartSession starts an SSH session with PTY
func (c *Client) StartSession(cols, rows int) error { func (c *Client) StartSession(cols, rows int) error {
if c.sshClient == nil { if c.sshClient == nil {

View File

@@ -1,50 +0,0 @@
//go:build js
package ssh
import (
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
keyStr := string(keyPEM)
if !strings.Contains(keyStr, "-----BEGIN") {
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
}
signer, err := ssh.ParsePrivateKey(keyPEM)
if err == nil {
return signer, nil
}
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
block, _ := pem.Decode(keyPEM)
if block == nil {
keyPreview := string(keyPEM)
if len(keyPreview) > 100 {
keyPreview = keyPreview[:100]
}
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
}
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(rsaKey)
}
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(ecKey)
}
return nil, fmt.Errorf("parse private key: %w", err)
}
return ssh.NewSignerFromKey(key)
}

16
go.mod
View File

@@ -1,6 +1,6 @@
module github.com/netbirdio/netbird module github.com/netbirdio/netbird
go 1.23.0 go 1.23.1
require ( require (
cunicu.li/go-rosenpass v0.4.0 cunicu.li/go-rosenpass v0.4.0
@@ -17,8 +17,8 @@ require (
github.com/spf13/cobra v1.7.0 github.com/spf13/cobra v1.7.0
github.com/spf13/pflag v1.0.5 github.com/spf13/pflag v1.0.5
github.com/vishvananda/netlink v1.3.0 github.com/vishvananda/netlink v1.3.0
golang.org/x/crypto v0.40.0 golang.org/x/crypto v0.41.0
golang.org/x/sys v0.34.0 golang.org/x/sys v0.35.0
golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1 golang.zx2c4.com/wireguard v0.0.0-20230704135630-469159ecf7d1
golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6 golang.zx2c4.com/wireguard/wgctrl v0.0.0-20230429144221-925a1e7659e6
golang.zx2c4.com/wireguard/windows v0.5.3 golang.zx2c4.com/wireguard/windows v0.5.3
@@ -31,6 +31,7 @@ require (
fyne.io/fyne/v2 v2.5.3 fyne.io/fyne/v2 v2.5.3
fyne.io/systray v1.11.0 fyne.io/systray v1.11.0
github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible github.com/TheJumpCloud/jcapi-go v3.0.0+incompatible
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.36.3 github.com/aws/aws-sdk-go-v2 v1.36.3
github.com/aws/aws-sdk-go-v2/config v1.29.14 github.com/aws/aws-sdk-go-v2/config v1.29.14
github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2 github.com/aws/aws-sdk-go-v2/service/s3 v1.79.2
@@ -103,11 +104,11 @@ require (
goauthentik.io/api/v3 v3.2023051.3 goauthentik.io/api/v3 v3.2023051.3
golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842 golang.org/x/exp v0.0.0-20240506185415-9bf2ced13842
golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a golang.org/x/mobile v0.0.0-20231127183840-76ac6878050a
golang.org/x/mod v0.25.0 golang.org/x/mod v0.26.0
golang.org/x/net v0.42.0 golang.org/x/net v0.42.0
golang.org/x/oauth2 v0.28.0 golang.org/x/oauth2 v0.28.0
golang.org/x/sync v0.16.0 golang.org/x/sync v0.16.0
golang.org/x/term v0.33.0 golang.org/x/term v0.34.0
google.golang.org/api v0.177.0 google.golang.org/api v0.177.0
gopkg.in/yaml.v3 v3.0.1 gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/mysql v1.5.7 gorm.io/driver/mysql v1.5.7
@@ -128,6 +129,7 @@ require (
github.com/Microsoft/go-winio v0.6.2 // indirect github.com/Microsoft/go-winio v0.6.2 // indirect
github.com/Microsoft/hcsshim v0.12.3 // indirect github.com/Microsoft/hcsshim v0.12.3 // indirect
github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect github.com/anmitsu/go-shlex v0.0.0-20200514113438-38f4b401e2be // indirect
github.com/awnumar/memcall v0.4.0 // indirect
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 // indirect
github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect github.com/aws/aws-sdk-go-v2/credentials v1.17.67 // indirect
github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect github.com/aws/aws-sdk-go-v2/feature/ec2/imds v1.16.30 // indirect
@@ -246,9 +248,9 @@ require (
go.uber.org/mock v0.4.0 // indirect go.uber.org/mock v0.4.0 // indirect
go.uber.org/multierr v1.11.0 // indirect go.uber.org/multierr v1.11.0 // indirect
golang.org/x/image v0.18.0 // indirect golang.org/x/image v0.18.0 // indirect
golang.org/x/text v0.27.0 // indirect golang.org/x/text v0.28.0 // indirect
golang.org/x/time v0.5.0 // indirect golang.org/x/time v0.5.0 // indirect
golang.org/x/tools v0.34.0 // indirect golang.org/x/tools v0.35.0 // indirect
golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect golang.zx2c4.com/wintun v0.0.0-20230126152724-0fa3db229ce2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect google.golang.org/genproto/googleapis/rpc v0.0.0-20250707201910-8d1bb00bc6a7 // indirect
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect

28
go.sum
View File

@@ -74,6 +74,10 @@ github.com/antihax/optional v1.0.0/go.mod h1:uupD/76wgC+ih3iEmQUL+0Ugr19nfwCT1kd
github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o= github.com/armon/circbuf v0.0.0-20150827004946-bbbad097214e/go.mod h1:3U/XgcO3hCbHZ8TKRvWD2dDTCfh9M9ya+I9JpbB7O8o=
github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY= github.com/armon/go-metrics v0.0.0-20180917152333-f0300d1749da/go.mod h1:Q73ZrmVTwzkszR9V5SSuryQ31EELlFMUz1kKyl939pY=
github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8= github.com/armon/go-radix v0.0.0-20180808171621-7fddfc383310/go.mod h1:ufUuZ+zHj4x4TnLV4JWEpy2hxWSpsRywHrMgIH9cCH8=
github.com/awnumar/memcall v0.4.0 h1:B7hgZYdfH6Ot1Goaz8jGne/7i8xD4taZie/PNSFZ29g=
github.com/awnumar/memcall v0.4.0/go.mod h1:8xOx1YbfyuCg3Fy6TO8DK0kZUua3V42/goA5Ru47E8w=
github.com/awnumar/memguard v0.23.0 h1:sJ3a1/SWlcuKIQ7MV+R9p0Pvo9CWsMbGZvcZQtmc68A=
github.com/awnumar/memguard v0.23.0/go.mod h1:olVofBrsPdITtJ2HgxQKrEYEMyIBAIciVG4wNnZhW9M=
github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM= github.com/aws/aws-sdk-go-v2 v1.36.3 h1:mJoei2CxPutQVxaATCzDUjcZEjVRdpsiiXi2o38yqWM=
github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg= github.com/aws/aws-sdk-go-v2 v1.36.3/go.mod h1:LLXuLpgzEbD766Z5ECcRmi8AzSwfZItDtmABVkRLGzg=
github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs= github.com/aws/aws-sdk-go-v2/aws/protocol/eventstream v1.6.10 h1:zAybnyUQXIZ5mok5Jqwlf58/TFE7uvd3IAsa1aF9cXs=
@@ -778,8 +782,8 @@ golang.org/x/crypto v0.18.0/go.mod h1:R0j02AL6hcrfOiy9T4ZYp/rcWeMxM3L6QYxlOuEG1m
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8= golang.org/x/crypto v0.23.0/go.mod h1:CKFgDieR+mRhux2Lsu27y0fO304Db0wZe70UKqHu0v8=
golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk= golang.org/x/crypto v0.31.0/go.mod h1:kDsLvtWBEx7MV9tJOj9bnXsPbxwJQ6csT/x4KIN4Ssk=
golang.org/x/crypto v0.40.0 h1:r4x+VvoG5Fm+eJcxMaY8CQM7Lb0l1lsmjGBQ6s8BfKM= golang.org/x/crypto v0.41.0 h1:WKYxWedPGCTVVl5+WHSSrOBT0O8lx32+zxmHxijgXp4=
golang.org/x/crypto v0.40.0/go.mod h1:Qr1vMER5WyS2dfPHAlsOj01wgLbsyWtFn/aY+5+ZdxY= golang.org/x/crypto v0.41.0/go.mod h1:pO5AFd7FA68rFak7rOAGVuygIISepHftHnr8dr6+sUc=
golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190121172915-509febef88a4/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA= golang.org/x/exp v0.0.0-20190306152737-a1d7652674e8/go.mod h1:CJ0aWSM057203Lf6IL+f9T1iT9GByDxfZKAQTCR3kQA=
golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8= golang.org/x/exp v0.0.0-20190510132918-efd6b22b2522/go.mod h1:ZjyILWgesfNpC6sMxTJOJm9Kp84zZh5NQWvqDGG3Qr8=
@@ -828,8 +832,8 @@ golang.org/x/mod v0.8.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs= golang.org/x/mod v0.12.0/go.mod h1:iBbtSCu2XBx23ZKBPSOrRkjjQPZFPuis4dIYUhu/chs=
golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.15.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c= golang.org/x/mod v0.17.0/go.mod h1:hTbmBsO62+eylJbnUtE2MGJUyE7QWk4xUqPFrRgJ+7c=
golang.org/x/mod v0.25.0 h1:n7a+ZbQKQA/Ysbyb0/6IbB1H/X41mKgbhfv7AfG/44w= golang.org/x/mod v0.26.0 h1:EGMPT//Ezu+ylkCijjPc+f4Aih7sZvaAr+O3EHBxvZg=
golang.org/x/mod v0.25.0/go.mod h1:IXM97Txy2VM4PJ3gI61r1YEk/gAj6zAHN3AdZt6S9Ww= golang.org/x/mod v0.26.0/go.mod h1:/j6NAhSk8iQ723BGAUyoAcn7SlD7s15Dp9Nd/SfeaFQ=
golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180724234803-3673e40ba225/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180826012351-8a410e7b638d/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180906233101-161cd47e91fd/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
@@ -985,8 +989,8 @@ golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.20.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.28.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.34.0 h1:H5Y5sJ2L2JRdyv7ROF1he/lPdvFsd0mJHFw2ThKHxLA= golang.org/x/sys v0.35.0 h1:vz1N37gP5bs89s7He8XuIYXpyY0+QlsKmzipCbUtyxI=
golang.org/x/sys v0.34.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k= golang.org/x/sys v0.35.0/go.mod h1:BJP2sWEmIv4KK5OTEluFJCKSidICx8ciO85XgH3Ak8k=
golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE= golang.org/x/telemetry v0.0.0-20240228155512-f48c80bd79b2/go.mod h1:TeRTkGYfJXctD9OcfyVLyj2J3IxLnKwHJR8f4D8a3YE=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
@@ -999,8 +1003,8 @@ golang.org/x/term v0.16.0/go.mod h1:yn7UURbUtPyrVJPGPq404EukNFxcm/foM+bV/bfcDsY=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk=
golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY= golang.org/x/term v0.20.0/go.mod h1:8UkIAJTvZgivsXaD6/pH6U9ecQzZ45awqEOzuCvwpFY=
golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM= golang.org/x/term v0.27.0/go.mod h1:iMsnZpn0cago0GOrHO2+Y7u7JPn5AylBrcoWkElMTSM=
golang.org/x/term v0.33.0 h1:NuFncQrRcaRvVmgRkvM3j/F00gWIAlcmlB8ACEKmGIg= golang.org/x/term v0.34.0 h1:O/2T7POpk0ZZ7MAzMeWFSg6S5IpWd/RXDlM9hgM3DR4=
golang.org/x/term v0.33.0/go.mod h1:s18+ql9tYWp1IfpV9DmCtQDDSRBUjKaw9M1eAv5UeF0= golang.org/x/term v0.34.0/go.mod h1:5jC53AEywhIVebHgPVeg0mj8OD3VO9OzclacVrqpaAw=
golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.1-0.20180807135948-17ff2d5776d2/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
@@ -1017,8 +1021,8 @@ golang.org/x/text v0.13.0/go.mod h1:TvPlkZtksWOMsz7fbANvkp4WM8x/WCo/om8BMLbz+aE=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.15.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ= golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/text v0.27.0 h1:4fGWRpyh641NLlecmyl4LOe6yDdfaYNrGb2zdfo4JV4= golang.org/x/text v0.28.0 h1:rhazDwis8INMIwQ4tpjLDzUhx6RlXqZNPEM0huQojng=
golang.org/x/text v0.27.0/go.mod h1:1D28KMCvyooCX9hBiosv5Tz/+YLxj0j7XhWjpSUF7CU= golang.org/x/text v0.28.0/go.mod h1:U8nCwOR8jO/marOQ0QbDiOngZVEBB7MAiitBuMjXiNU=
golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20190308202827-9d24e82272b4/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ= golang.org/x/time v0.0.0-20191024005414-555d28b269f0/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -1083,8 +1087,8 @@ golang.org/x/tools v0.1.12/go.mod h1:hNGJHUnrk76NpqgfD5Aqm5Crs+Hm0VOH/i9J2+nxYbc
golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU= golang.org/x/tools v0.6.0/go.mod h1:Xwgl3UAJ/d3gWutnCtw505GrjyAbvKui8lOU390QaIU=
golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58= golang.org/x/tools v0.13.0/go.mod h1:HvlwmtVNQAhOuCjW7xxvovg8wbNq7LwfXh/k7wXUl58=
golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk= golang.org/x/tools v0.21.1-0.20240508182429-e35e4ccd0d2d/go.mod h1:aiJjzUbINMkxbQROHiO6hDPo2LHcIPhhQsa9DLh0yGk=
golang.org/x/tools v0.34.0 h1:qIpSLOxeCYGg9TrcJokLBG4KFA6d795g0xkBkiESGlo= golang.org/x/tools v0.35.0 h1:mBffYraMEf7aa0sB+NuKnuCy8qI/9Bughn8dC2Gu5r0=
golang.org/x/tools v0.34.0/go.mod h1:pAP9OwEaY1CAW3HOmg3hLZC5Z0CCmzjAF2UQMSqNARg= golang.org/x/tools v0.35.0/go.mod h1:NKdj5HkL/73byiZSJjqJgKn3ep7KjFkBOkR/Hps3VPw=
golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20190717185122-a985d3407aa7/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191011141410-1b5146add898/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0= golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=

View File

@@ -6,6 +6,7 @@ import (
"fmt" "fmt"
"net" "net"
"net/netip" "net/netip"
"net/url"
"os" "os"
"strings" "strings"
"sync" "sync"
@@ -21,6 +22,7 @@ import (
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
integrationsConfig "github.com/netbirdio/management-integrations/integrations/config" integrationsConfig "github.com/netbirdio/management-integrations/integrations/config"
nbconfig "github.com/netbirdio/netbird/management/internals/server/config" nbconfig "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/server/peers/ephemeral" "github.com/netbirdio/netbird/management/server/peers/ephemeral"
@@ -588,7 +590,7 @@ func (s *GRPCServer) prepareLoginResponse(ctx context.Context, peer *nbpeer.Peer
// if peer has reached this point then it has logged in // if peer has reached this point then it has logged in
loginResp := &proto.LoginResponse{ loginResp := &proto.LoginResponse{
NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil), NetbirdConfig: toNetbirdConfig(s.config, nil, relayToken, nil),
PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings), PeerConfig: toPeerConfig(peer, netMap.Network, s.accountManager.GetDNSDomain(settings), settings, s.config),
Checks: toProtocolChecks(ctx, postureChecks), Checks: toProtocolChecks(ctx, postureChecks),
} }
@@ -703,12 +705,21 @@ func toNetbirdConfig(config *nbconfig.Config, turnCredentials *Token, relayToken
return nbConfig return nbConfig
} }
func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings) *proto.PeerConfig { func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, settings *types.Settings, config *nbconfig.Config) *proto.PeerConfig {
netmask, _ := network.Net.Mask.Size() netmask, _ := network.Net.Mask.Size()
fqdn := peer.FQDN(dnsName) fqdn := peer.FQDN(dnsName)
sshConfig := &proto.SSHConfig{
SshEnabled: peer.SSHEnabled,
}
if peer.SSHEnabled {
sshConfig.JwtConfig = buildJWTConfig(config)
}
return &proto.PeerConfig{ return &proto.PeerConfig{
Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask), // take it from the network Address: fmt.Sprintf("%s/%d", peer.IP.String(), netmask),
SshConfig: &proto.SSHConfig{SshEnabled: peer.SSHEnabled}, SshConfig: sshConfig,
Fqdn: fqdn, Fqdn: fqdn,
RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled, RoutingPeerDnsResolutionEnabled: settings.RoutingPeerDNSResolutionEnabled,
LazyConnectionEnabled: settings.LazyConnectionEnabled, LazyConnectionEnabled: settings.LazyConnectionEnabled,
@@ -717,7 +728,7 @@ func toPeerConfig(peer *nbpeer.Peer, network *types.Network, dnsName string, set
func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse { func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.Peer, turnCredentials *Token, relayCredentials *Token, networkMap *types.NetworkMap, dnsName string, checks []*posture.Checks, dnsCache *DNSConfigCache, settings *types.Settings, extraSettings *types.ExtraSettings, peerGroups []string, dnsFwdPort int64) *proto.SyncResponse {
response := &proto.SyncResponse{ response := &proto.SyncResponse{
PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings), PeerConfig: toPeerConfig(peer, networkMap.Network, dnsName, settings, config),
NetworkMap: &proto.NetworkMap{ NetworkMap: &proto.NetworkMap{
Serial: networkMap.Network.CurrentSerial(), Serial: networkMap.Network.CurrentSerial(),
Routes: toProtocolRoutes(networkMap.Routes), Routes: toProtocolRoutes(networkMap.Routes),
@@ -760,6 +771,55 @@ func toSyncResponse(ctx context.Context, config *nbconfig.Config, peer *nbpeer.P
return response return response
} }
// buildJWTConfig constructs JWT configuration for SSH servers from management server config
func buildJWTConfig(config *nbconfig.Config) *proto.JWTConfig {
if config == nil {
return nil
}
if config.HttpConfig == nil || config.HttpConfig.AuthAudience == "" {
return nil
}
var tokenEndpoint string
if config.DeviceAuthorizationFlow != nil {
tokenEndpoint = config.DeviceAuthorizationFlow.ProviderConfig.TokenEndpoint
}
issuer := deriveIssuerFromTokenEndpoint(tokenEndpoint)
if issuer == "" && config.HttpConfig.AuthIssuer != "" {
issuer = config.HttpConfig.AuthIssuer
}
if issuer == "" {
return nil
}
keysLocation := config.HttpConfig.AuthKeysLocation
if keysLocation == "" {
keysLocation = strings.TrimSuffix(issuer, "/") + "/.well-known/jwks.json"
}
return &proto.JWTConfig{
Issuer: issuer,
Audience: config.HttpConfig.AuthAudience,
KeysLocation: keysLocation,
}
}
// deriveIssuerFromTokenEndpoint extracts the issuer URL from a token endpoint
func deriveIssuerFromTokenEndpoint(tokenEndpoint string) string {
if tokenEndpoint == "" {
return ""
}
u, err := url.Parse(tokenEndpoint)
if err != nil {
return ""
}
return fmt.Sprintf("%s://%s/", u.Scheme, u.Host)
}
func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig { func appendRemotePeerConfig(dst []*proto.RemotePeerConfig, peers []*nbpeer.Peer, dnsName string) []*proto.RemotePeerConfig {
for _, rPeer := range peers { for _, rPeer := range peers {
dst = append(dst, &proto.RemotePeerConfig{ dst = append(dst, &proto.RemotePeerConfig{

File diff suppressed because it is too large Load Diff

View File

@@ -151,6 +151,7 @@ message Flags {
bool enableSSHSFTP = 12; bool enableSSHSFTP = 12;
bool enableSSHLocalPortForwarding = 13; bool enableSSHLocalPortForwarding = 13;
bool enableSSHRemotePortForwarding = 14; bool enableSSHRemotePortForwarding = 14;
bool disableSSHAuth = 15;
} }
// PeerSystemMeta is machine meta data like OS and version. // PeerSystemMeta is machine meta data like OS and version.
@@ -207,6 +208,8 @@ message NetbirdConfig {
RelayConfig relay = 4; RelayConfig relay = 4;
FlowConfig flow = 5; FlowConfig flow = 5;
JWTConfig jwt = 6;
} }
// HostConfig describes connection properties of some server (e.g. STUN, Signal, Management) // HostConfig describes connection properties of some server (e.g. STUN, Signal, Management)
@@ -245,6 +248,14 @@ message FlowConfig {
bool dnsCollection = 8; bool dnsCollection = 8;
} }
// JWTConfig represents JWT authentication configuration
message JWTConfig {
string issuer = 1;
string audience = 2;
string keysLocation = 3;
int64 maxTokenAge = 4;
}
// ProtectedHostConfig is similar to HostConfig but has additional user and password // ProtectedHostConfig is similar to HostConfig but has additional user and password
// Mostly used for TURN servers // Mostly used for TURN servers
message ProtectedHostConfig { message ProtectedHostConfig {
@@ -340,6 +351,8 @@ message SSHConfig {
// sshPubKey is a SSH public key of a peer to be added to authorized_hosts. // sshPubKey is a SSH public key of a peer to be added to authorized_hosts.
// This property should be ignore if SSHConfig comes from PeerConfig. // This property should be ignore if SSHConfig comes from PeerConfig.
bytes sshPubKey = 2; bytes sshPubKey = 2;
JWTConfig jwtConfig = 3;
} }
// DeviceAuthorizationFlowRequest empty struct for future expansion // DeviceAuthorizationFlowRequest empty struct for future expansion