Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -17,9 +17,9 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/client/net"
)
// ConnectionListener export internal Listener for mobile

View File

@@ -5,10 +5,12 @@ import (
"errors"
"flag"
"fmt"
"net"
"os"
"os/signal"
"os/user"
"slices"
"strconv"
"strings"
"syscall"
@@ -16,6 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal"
sshclient "github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
sshproxy "github.com/netbirdio/netbird/client/ssh/proxy"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/util"
)
@@ -29,6 +33,7 @@ const (
enableSSHSFTPFlag = "enable-ssh-sftp"
enableSSHLocalPortForwardFlag = "enable-ssh-local-port-forwarding"
enableSSHRemotePortForwardFlag = "enable-ssh-remote-port-forwarding"
disableSSHAuthFlag = "disable-ssh-auth"
)
var (
@@ -41,6 +46,7 @@ var (
strictHostKeyChecking bool
knownHostsFile string
identityFile string
skipCachedToken bool
)
var (
@@ -49,6 +55,7 @@ var (
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
)
func init() {
@@ -57,6 +64,7 @@ func init() {
upCmd.PersistentFlags().BoolVar(&enableSSHSFTP, enableSSHSFTPFlag, false, "Enable SFTP subsystem for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHLocalPortForward, enableSSHLocalPortForwardFlag, false, "Enable local port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&enableSSHRemotePortForward, enableSSHRemotePortForwardFlag, false, "Enable remote port forwarding for SSH server")
upCmd.PersistentFlags().BoolVar(&disableSSHAuth, disableSSHAuthFlag, false, "Disable SSH authentication")
sshCmd.PersistentFlags().IntVarP(&port, "port", "p", sshserver.DefaultSSHPort, "Remote SSH port")
sshCmd.PersistentFlags().StringVarP(&username, "user", "u", "", sshUsernameDesc)
@@ -64,11 +72,14 @@ func init() {
sshCmd.PersistentFlags().BoolVar(&strictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking (default: true)")
sshCmd.PersistentFlags().StringVarP(&knownHostsFile, "known-hosts", "o", "", "Path to known_hosts file (default: ~/.ssh/known_hosts)")
sshCmd.PersistentFlags().StringVarP(&identityFile, "identity", "i", "", "Path to SSH private key file")
sshCmd.PersistentFlags().BoolVar(&skipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
sshCmd.PersistentFlags().StringArrayP("L", "L", []string{}, "Local port forwarding [bind_address:]port:host:hostport")
sshCmd.PersistentFlags().StringArrayP("R", "R", []string{}, "Remote port forwarding [bind_address:]port:host:hostport")
sshCmd.AddCommand(sshSftpCmd)
sshCmd.AddCommand(sshProxyCmd)
sshCmd.AddCommand(sshDetectCmd)
}
var sshCmd = &cobra.Command{
@@ -335,31 +346,51 @@ func parseSpacedFormat(arg string, args []string, currentIndex int, flagHandlers
}
// createSSHFlagSet creates and configures the flag set for SSH command parsing
func createSSHFlagSet() (*flag.FlagSet, *int, *string, *string, *bool, *string, *string, *string, *string) {
// sshFlags contains all SSH-related flags and parameters
type sshFlags struct {
Port int
Username string
Login string
StrictHostKeyChecking bool
KnownHostsFile string
IdentityFile string
SkipCachedToken bool
ConfigPath string
LogLevel string
LocalForwards []string
RemoteForwards []string
Host string
Command string
}
func createSSHFlagSet() (*flag.FlagSet, *sshFlags) {
defaultConfigPath := getEnvOrDefault("CONFIG", configPath)
defaultLogLevel := getEnvOrDefault("LOG_LEVEL", logLevel)
fs := flag.NewFlagSet("ssh-flags", flag.ContinueOnError)
fs.SetOutput(nil)
portFlag := fs.Int("p", sshserver.DefaultSSHPort, "SSH port")
flags := &sshFlags{}
fs.IntVar(&flags.Port, "p", sshserver.DefaultSSHPort, "SSH port")
fs.Int("port", sshserver.DefaultSSHPort, "SSH port")
userFlag := fs.String("u", "", sshUsernameDesc)
fs.StringVar(&flags.Username, "u", "", sshUsernameDesc)
fs.String("user", "", sshUsernameDesc)
loginFlag := fs.String("login", "", sshUsernameDesc+" (alias for --user)")
fs.StringVar(&flags.Login, "login", "", sshUsernameDesc+" (alias for --user)")
strictHostKeyCheckingFlag := fs.Bool("strict-host-key-checking", true, "Enable strict host key checking")
knownHostsFlag := fs.String("o", "", "Path to known_hosts file")
fs.BoolVar(&flags.StrictHostKeyChecking, "strict-host-key-checking", true, "Enable strict host key checking")
fs.StringVar(&flags.KnownHostsFile, "o", "", "Path to known_hosts file")
fs.String("known-hosts", "", "Path to known_hosts file")
identityFlag := fs.String("i", "", "Path to SSH private key file")
fs.StringVar(&flags.IdentityFile, "i", "", "Path to SSH private key file")
fs.String("identity", "", "Path to SSH private key file")
fs.BoolVar(&flags.SkipCachedToken, "no-cache", false, "Skip cached JWT token and force fresh authentication")
configFlag := fs.String("c", defaultConfigPath, "Netbird config file location")
fs.StringVar(&flags.ConfigPath, "c", defaultConfigPath, "Netbird config file location")
fs.String("config", defaultConfigPath, "Netbird config file location")
logLevelFlag := fs.String("l", defaultLogLevel, "sets Netbird log level")
fs.StringVar(&flags.LogLevel, "l", defaultLogLevel, "sets Netbird log level")
fs.String("log-level", defaultLogLevel, "sets Netbird log level")
return fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag
return fs, flags
}
func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
@@ -375,7 +406,7 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
filteredArgs, localForwardFlags, remoteForwardFlags := parseCustomSSHFlags(args)
fs, portFlag, userFlag, loginFlag, strictHostKeyCheckingFlag, knownHostsFlag, identityFlag, configFlag, logLevelFlag := createSSHFlagSet()
fs, flags := createSSHFlagSet()
if err := fs.Parse(filteredArgs); err != nil {
return parseHostnameAndCommand(filteredArgs)
@@ -386,22 +417,23 @@ func validateSSHArgsWithoutFlagParsing(_ *cobra.Command, args []string) error {
return errors.New(hostArgumentRequired)
}
port = *portFlag
if *userFlag != "" {
username = *userFlag
} else if *loginFlag != "" {
username = *loginFlag
port = flags.Port
if flags.Username != "" {
username = flags.Username
} else if flags.Login != "" {
username = flags.Login
}
strictHostKeyChecking = *strictHostKeyCheckingFlag
knownHostsFile = *knownHostsFlag
identityFile = *identityFlag
strictHostKeyChecking = flags.StrictHostKeyChecking
knownHostsFile = flags.KnownHostsFile
identityFile = flags.IdentityFile
skipCachedToken = flags.SkipCachedToken
if *configFlag != getEnvOrDefault("CONFIG", configPath) {
configPath = *configFlag
if flags.ConfigPath != getEnvOrDefault("CONFIG", configPath) {
configPath = flags.ConfigPath
}
if *logLevelFlag != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = *logLevelFlag
if flags.LogLevel != getEnvOrDefault("LOG_LEVEL", logLevel) {
logLevel = flags.LogLevel
}
localForwards = localForwardFlags
@@ -449,30 +481,20 @@ func parseHostnameAndCommand(args []string) error {
func runSSH(ctx context.Context, addr string, cmd *cobra.Command) error {
target := fmt.Sprintf("%s:%d", addr, port)
var c *sshclient.Client
var err error
if strictHostKeyChecking {
c, err = sshclient.DialWithOptions(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
})
} else {
c, err = sshclient.DialInsecure(ctx, target, username)
}
c, err := sshclient.Dial(ctx, target, username, sshclient.DialOptions{
KnownHostsFile: knownHostsFile,
IdentityFile: identityFile,
DaemonAddr: daemonAddr,
SkipCachedToken: skipCachedToken,
InsecureSkipVerify: !strictHostKeyChecking,
})
if err != nil {
cmd.Printf("Failed to connect to %s@%s\n", username, target)
cmd.Printf("\nTroubleshooting steps:\n")
cmd.Printf(" 1. Check peer connectivity: netbird status\n")
cmd.Printf(" 1. Check peer connectivity: netbird status -d\n")
cmd.Printf(" 2. Verify SSH server is enabled on the peer\n")
cmd.Printf(" 3. Ensure correct hostname/IP is used\n")
if strictHostKeyChecking {
cmd.Printf(" 4. Try --strict-host-key-checking=false to bypass host key verification\n")
}
cmd.Printf("\n")
return fmt.Errorf("dial %s: %w", target, err)
}
@@ -665,3 +687,65 @@ func normalizeLocalHost(host string) string {
}
return host
}
var sshProxyCmd = &cobra.Command{
Use: "proxy <host> <port>",
Short: "Internal SSH proxy for native SSH client integration",
Long: "Internal command used by SSH ProxyCommand to handle JWT authentication",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshProxyFn,
}
func sshProxyFn(cmd *cobra.Command, args []string) error {
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
return fmt.Errorf("invalid port: %s", portStr)
}
proxy, err := sshproxy.New(daemonAddr, host, port, cmd.ErrOrStderr())
if err != nil {
return fmt.Errorf("create SSH proxy: %w", err)
}
if err := proxy.Connect(cmd.Context()); err != nil {
return fmt.Errorf("SSH proxy: %w", err)
}
return nil
}
var sshDetectCmd = &cobra.Command{
Use: "detect <host> <port>",
Short: "Detect if a host is running NetBird SSH",
Long: "Internal command used by SSH Match exec to detect NetBird SSH servers. Exit codes: 0=JWT, 1=no-JWT, 2=regular SSH",
Hidden: true,
Args: cobra.ExactArgs(2),
RunE: sshDetectFn,
}
func sshDetectFn(cmd *cobra.Command, args []string) error {
if err := util.InitLog(logLevel, "console"); err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
host := args[0]
portStr := args[1]
port, err := strconv.Atoi(portStr)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(cmd.Context(), dialer, host, port)
if err != nil {
os.Exit(detection.ServerTypeRegular.ExitCode())
}
os.Exit(serverType.ExitCode())
return nil
}

View File

@@ -360,6 +360,9 @@ func setupSetConfigReq(customDNSAddressConverted []byte, cmd *cobra.Command, pro
if cmd.Flag(enableSSHRemotePortForwardFlag).Changed {
req.EnableSSHRemotePortForward = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
req.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
log.Errorf("parse interface name: %v", err)
@@ -460,6 +463,10 @@ func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command, configFil
ic.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
ic.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(interfaceNameFlag).Changed {
if err := parseInterfaceName(interfaceName); err != nil {
return nil, err
@@ -576,6 +583,10 @@ func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte
loginRequest.EnableSSHRemotePortForwarding = &enableSSHRemotePortForward
}
if cmd.Flag(disableSSHAuthFlag).Changed {
loginRequest.DisableSSHAuth = &disableSSHAuth
}
if cmd.Flag(disableAutoConnectFlag).Changed {
loginRequest.DisableAutoConnect = &autoConnectDisabled
}

View File

@@ -18,12 +18,16 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
sshcommon "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
)
var ErrClientAlreadyStarted = errors.New("client already started")
var ErrClientNotStarted = errors.New("client not started")
var ErrConfigNotInitialized = errors.New("config not initialized")
var (
ErrClientAlreadyStarted = errors.New("client already started")
ErrClientNotStarted = errors.New("client not started")
ErrEngineNotStarted = errors.New("engine not started")
ErrConfigNotInitialized = errors.New("config not initialized")
)
// Client manages a netbird embedded client instance.
type Client struct {
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
// Dial dials a network address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, ErrClientNotStarted
}
c.mu.Unlock()
engine := connect.Engine()
if engine == nil {
return nil, errors.New("engine not started")
engine, err := c.getEngine()
if err != nil {
return nil, err
}
nsnet, err := engine.GetNet()
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
return nsnet.DialContext(ctx, network, address)
}
// DialContext dials a network address in the netbird network with context
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
return c.Dial(ctx, network, address)
}
// ListenTCP listens on the given address in the netbird network.
// Not applicable if the userspace networking mode is disabled.
func (c *Client) ListenTCP(address string) (net.Listener, error) {
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
}
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
engine, err := c.getEngine()
if err != nil {
return err
}
storedKey, found := engine.GetPeerSSHKey(peerAddress)
if !found {
return sshcommon.ErrPeerNotFound
}
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.
func (c *Client) getEngine() (*internal.Engine, error) {
c.mu.Lock()
connect := c.connect
if connect == nil {
c.mu.Unlock()
return nil, netip.Addr{}, errors.New("client not started")
}
c.mu.Unlock()
if connect == nil {
return nil, ErrClientNotStarted
}
engine := connect.Engine()
if engine == nil {
return nil, netip.Addr{}, errors.New("engine not started")
return nil, ErrEngineNotStarted
}
return engine, nil
}
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
engine, err := c.getEngine()
if err != nil {
return nil, netip.Addr{}, err
}
addr, err := engine.Address()

View File

@@ -87,7 +87,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := d.squashAcceptRules(networkMap)
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
// we have old version of management without rules handling, we should allow all traffic
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
@@ -350,7 +349,7 @@ func (d *DefaultManager) getPeerRuleID(
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap, ) []*mgmProto.FirewallRule {
func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap) []*mgmProto.FirewallRule {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -437,6 +437,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
@@ -527,6 +528,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -49,6 +49,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -117,6 +118,7 @@ type EngineConfig struct {
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DNSRouteInterval time.Duration
@@ -264,6 +266,7 @@ func NewEngine(
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
@@ -676,14 +679,10 @@ func (e *Engine) removeAllPeers() error {
return nil
}
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
// removePeer closes an existing peer connection and removes a peer
func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey)
if e.sshServer != nil {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey)
@@ -859,6 +858,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -920,6 +920,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1074,24 +1075,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys
if e.sshServer != nil {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
if err != nil {
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
}
}
}
}
// update peer SSH host keys in status recorder for daemon API access
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
// update SSH client known_hosts with peer host keys for OpenSSH client
if err := e.updateSSHKnownHosts(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH known_hosts: %v", err)
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
}
@@ -1480,6 +1467,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)

View File

@@ -5,10 +5,8 @@ import (
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
@@ -21,9 +19,6 @@ import (
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
RemoveAuthorizedKey(peer string)
AddAuthorizedKey(peer, newKey string) error
SetSocketFilter(ifIdx int)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -44,22 +39,6 @@ func (e *Engine) setupSSHPortRedirection() error {
return nil
}
func (e *Engine) setupSSHSocketFilter(server sshServer) error {
if runtime.GOOS != "linux" {
return nil
}
netInterface := e.wgInterface.ToInterface()
if netInterface == nil {
return errors.New("failed to get WireGuard network interface")
}
server.SetSocketFilter(netInterface.Index)
log.Debugf("SSH socket filter configured for interface %s (index: %d)", netInterface.Name, netInterface.Index)
return nil
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Info("SSH server is disabled because inbound connections are blocked")
@@ -83,66 +62,76 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
return nil
}
return e.startSSHServer()
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
return e.startSSHServer(nil)
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
return e.startSSHServer(jwtConfig)
}
return errors.New("SSH server requires valid JWT configuration")
}
// updateSSHKnownHosts updates the SSH known_hosts file with peer host keys for OpenSSH client
func (e *Engine) updateSSHKnownHosts(remotePeers []*mgmProto.RemotePeerConfig) error {
peerKeys := e.extractPeerHostKeys(remotePeers)
if len(peerKeys) == 0 {
log.Debug("no SSH-enabled peers found, skipping known_hosts update")
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
return nil
}
if err := e.updateKnownHostsFile(peerKeys); err != nil {
return err
configMgr := sshconfig.New()
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
return nil // Don't fail engine startup on SSH config issues
}
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
SSHConfigDir: configMgr.GetSSHConfigDir(),
SSHConfigFile: configMgr.GetSSHConfigFile(),
}); err != nil {
log.Warnf("failed to update SSH config state: %v", err)
}
e.updateSSHClientConfig(peerKeys)
log.Debugf("updated SSH known_hosts with %d peer host keys", len(peerKeys))
return nil
}
// extractPeerHostKeys extracts SSH host keys from peer configurations
func (e *Engine) extractPeerHostKeys(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerHostKey {
var peerKeys []sshconfig.PeerHostKey
// extractPeerSSHInfo extracts SSH information from peer configurations
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
var peerInfo []sshconfig.PeerSSHInfo
for _, peerConfig := range remotePeers {
peerHostKey, ok := e.parsePeerHostKey(peerConfig)
if ok {
peerKeys = append(peerKeys, peerHostKey)
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
})
}
return peerKeys
}
// parsePeerHostKey parses a single peer's SSH host key configuration
func (e *Engine) parsePeerHostKey(peerConfig *mgmProto.RemotePeerConfig) (sshconfig.PeerHostKey, bool) {
if peerConfig.GetSshConfig() == nil {
return sshconfig.PeerHostKey{}, false
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
return sshconfig.PeerHostKey{}, false
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyBytes)
if err != nil {
log.Warnf("failed to parse SSH public key for peer %s: %v", peerConfig.GetWgPubKey(), err)
return sshconfig.PeerHostKey{}, false
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
return sshconfig.PeerHostKey{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
HostKey: hostKey,
}, true
return peerInfo
}
// extractPeerIP extracts IP address from peer's allowed IPs
@@ -171,25 +160,6 @@ func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
return ""
}
// updateKnownHostsFile updates the SSH known_hosts file
func (e *Engine) updateKnownHostsFile(peerKeys []sshconfig.PeerHostKey) error {
configMgr := sshconfig.NewManager()
if err := configMgr.UpdatePeerHostKeys(peerKeys); err != nil {
return fmt.Errorf("update peer host keys: %w", err)
}
return nil
}
// updateSSHClientConfig updates SSH client configuration with peer hostnames
func (e *Engine) updateSSHClientConfig(peerKeys []sshconfig.PeerHostKey) {
configMgr := sshconfig.NewManager()
if err := configMgr.SetupSSHClientConfig(peerKeys); err != nil {
log.Warnf("failed to update SSH client config with peer hostnames: %v", err)
} else {
log.Debugf("updated SSH client config with %d peer hostnames", len(peerKeys))
}
}
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
for _, peerConfig := range remotePeers {
@@ -210,30 +180,51 @@ func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig)
log.Debugf("updated peer SSH host keys for daemon API access")
}
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
e.syncMsgMux.Lock()
statusRecorder := e.statusRecorder
e.syncMsgMux.Unlock()
if statusRecorder == nil {
return nil, false
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
return peerState.SSHHostKey, true
}
return nil, false
}
}
return nil, false
}
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
configMgr := sshconfig.NewManager()
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {
log.Warnf("failed to remove SSH client config: %v", err)
} else {
log.Debugf("SSH client config cleanup completed")
}
if err := configMgr.RemoveKnownHostsFile(); err != nil {
log.Warnf("failed to remove SSH known_hosts: %v", err)
} else {
log.Debugf("SSH known_hosts cleanup completed")
}
}
// startSSHServer initializes and starts the SSH server with proper configuration.
func (e *Engine) startSSHServer() error {
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
server := sshserver.New(e.config.SSHKey)
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
@@ -259,15 +250,10 @@ func (e *Engine) startSSHServer() error {
log.Warnf("failed to setup SSH port redirection: %v", err)
}
if err := e.setupSSHSocketFilter(server); err != nil {
return fmt.Errorf("set socket filter: %w", err)
}
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
return nil
}

View File

@@ -281,7 +281,15 @@ func TestEngine_SSH(t *testing.T) {
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}

View File

@@ -128,6 +128,7 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
@@ -158,6 +159,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10

View File

@@ -54,6 +54,7 @@ type ConfigInput struct {
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
@@ -102,6 +103,7 @@ type Config struct {
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -423,6 +425,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
if *input.DisableSSHAuth {
log.Infof("disabling SSH authentication")
} else {
log.Infof("enabling SSH authentication")
}
config.DisableSSHAuth = input.DisableSSHAuth
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())

View File

@@ -18,8 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (

View File

@@ -20,8 +20,8 @@ import (
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
// ConnectionListener export internal Listener for mobile

View File

@@ -283,6 +283,7 @@ type LoginRequest struct {
EnableSSHSFTP *bool `protobuf:"varint,34,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
EnableSSHLocalPortForwarding *bool `protobuf:"varint,35,opt,name=enableSSHLocalPortForwarding,proto3,oneof" json:"enableSSHLocalPortForwarding,omitempty"`
EnableSSHRemotePortForwarding *bool `protobuf:"varint,36,opt,name=enableSSHRemotePortForwarding,proto3,oneof" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,37,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -570,6 +571,13 @@ func (x *LoginRequest) GetEnableSSHRemotePortForwarding() bool {
return false
}
func (x *LoginRequest) GetDisableSSHAuth() bool {
if x != nil && x.DisableSSHAuth != nil {
return *x.DisableSSHAuth
}
return false
}
type LoginResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
NeedsSSOLogin bool `protobuf:"varint,1,opt,name=needsSSOLogin,proto3" json:"needsSSOLogin,omitempty"`
@@ -1100,6 +1108,7 @@ type GetConfigResponse struct {
EnableSSHSFTP bool `protobuf:"varint,24,opt,name=enableSSHSFTP,proto3" json:"enableSSHSFTP,omitempty"`
EnableSSHLocalPortForwarding bool `protobuf:"varint,22,opt,name=enableSSHLocalPortForwarding,proto3" json:"enableSSHLocalPortForwarding,omitempty"`
EnableSSHRemotePortForwarding bool `protobuf:"varint,23,opt,name=enableSSHRemotePortForwarding,proto3" json:"enableSSHRemotePortForwarding,omitempty"`
DisableSSHAuth bool `protobuf:"varint,25,opt,name=disableSSHAuth,proto3" json:"disableSSHAuth,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -1302,6 +1311,13 @@ func (x *GetConfigResponse) GetEnableSSHRemotePortForwarding() bool {
return false
}
func (x *GetConfigResponse) GetDisableSSHAuth() bool {
if x != nil {
return x.DisableSSHAuth
}
return false
}
// PeerState contains the latest state of a peer
type PeerState struct {
state protoimpl.MessageState `protogen:"open.v1"`
@@ -3781,6 +3797,7 @@ type SetConfigRequest struct {
EnableSSHSFTP *bool `protobuf:"varint,30,opt,name=enableSSHSFTP,proto3,oneof" json:"enableSSHSFTP,omitempty"`
EnableSSHLocalPortForward *bool `protobuf:"varint,31,opt,name=enableSSHLocalPortForward,proto3,oneof" json:"enableSSHLocalPortForward,omitempty"`
EnableSSHRemotePortForward *bool `protobuf:"varint,32,opt,name=enableSSHRemotePortForward,proto3,oneof" json:"enableSSHRemotePortForward,omitempty"`
DisableSSHAuth *bool `protobuf:"varint,33,opt,name=disableSSHAuth,proto3,oneof" json:"disableSSHAuth,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
@@ -4039,6 +4056,13 @@ func (x *SetConfigRequest) GetEnableSSHRemotePortForward() bool {
return false
}
func (x *SetConfigRequest) GetDisableSSHAuth() bool {
if x != nil && x.DisableSSHAuth != nil {
return *x.DisableSSHAuth
}
return false
}
type SetConfigResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
@@ -4774,6 +4798,262 @@ func (x *GetPeerSSHHostKeyResponse) GetFound() bool {
return false
}
// RequestJWTAuthRequest for initiating JWT authentication flow
type RequestJWTAuthRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RequestJWTAuthRequest) Reset() {
*x = RequestJWTAuthRequest{}
mi := &file_daemon_proto_msgTypes[71]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RequestJWTAuthRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RequestJWTAuthRequest) ProtoMessage() {}
func (x *RequestJWTAuthRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[71]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RequestJWTAuthRequest.ProtoReflect.Descriptor instead.
func (*RequestJWTAuthRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{71}
}
// RequestJWTAuthResponse contains authentication flow information
type RequestJWTAuthResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// verification URI for user authentication
VerificationURI string `protobuf:"bytes,1,opt,name=verificationURI,proto3" json:"verificationURI,omitempty"`
// complete verification URI (with embedded user code)
VerificationURIComplete string `protobuf:"bytes,2,opt,name=verificationURIComplete,proto3" json:"verificationURIComplete,omitempty"`
// user code to enter on verification URI
UserCode string `protobuf:"bytes,3,opt,name=userCode,proto3" json:"userCode,omitempty"`
// device code for polling
DeviceCode string `protobuf:"bytes,4,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
// expiration time in seconds
ExpiresIn int64 `protobuf:"varint,5,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
// if a cached token is available, it will be returned here
CachedToken string `protobuf:"bytes,6,opt,name=cachedToken,proto3" json:"cachedToken,omitempty"`
// maximum age of JWT tokens in seconds (from management server)
MaxTokenAge int64 `protobuf:"varint,7,opt,name=maxTokenAge,proto3" json:"maxTokenAge,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *RequestJWTAuthResponse) Reset() {
*x = RequestJWTAuthResponse{}
mi := &file_daemon_proto_msgTypes[72]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *RequestJWTAuthResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*RequestJWTAuthResponse) ProtoMessage() {}
func (x *RequestJWTAuthResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[72]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use RequestJWTAuthResponse.ProtoReflect.Descriptor instead.
func (*RequestJWTAuthResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{72}
}
func (x *RequestJWTAuthResponse) GetVerificationURI() string {
if x != nil {
return x.VerificationURI
}
return ""
}
func (x *RequestJWTAuthResponse) GetVerificationURIComplete() string {
if x != nil {
return x.VerificationURIComplete
}
return ""
}
func (x *RequestJWTAuthResponse) GetUserCode() string {
if x != nil {
return x.UserCode
}
return ""
}
func (x *RequestJWTAuthResponse) GetDeviceCode() string {
if x != nil {
return x.DeviceCode
}
return ""
}
func (x *RequestJWTAuthResponse) GetExpiresIn() int64 {
if x != nil {
return x.ExpiresIn
}
return 0
}
func (x *RequestJWTAuthResponse) GetCachedToken() string {
if x != nil {
return x.CachedToken
}
return ""
}
func (x *RequestJWTAuthResponse) GetMaxTokenAge() int64 {
if x != nil {
return x.MaxTokenAge
}
return 0
}
// WaitJWTTokenRequest for waiting for authentication completion
type WaitJWTTokenRequest struct {
state protoimpl.MessageState `protogen:"open.v1"`
// device code from RequestJWTAuthResponse
DeviceCode string `protobuf:"bytes,1,opt,name=deviceCode,proto3" json:"deviceCode,omitempty"`
// user code for verification
UserCode string `protobuf:"bytes,2,opt,name=userCode,proto3" json:"userCode,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *WaitJWTTokenRequest) Reset() {
*x = WaitJWTTokenRequest{}
mi := &file_daemon_proto_msgTypes[73]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *WaitJWTTokenRequest) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*WaitJWTTokenRequest) ProtoMessage() {}
func (x *WaitJWTTokenRequest) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[73]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use WaitJWTTokenRequest.ProtoReflect.Descriptor instead.
func (*WaitJWTTokenRequest) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{73}
}
func (x *WaitJWTTokenRequest) GetDeviceCode() string {
if x != nil {
return x.DeviceCode
}
return ""
}
func (x *WaitJWTTokenRequest) GetUserCode() string {
if x != nil {
return x.UserCode
}
return ""
}
// WaitJWTTokenResponse contains the JWT token after authentication
type WaitJWTTokenResponse struct {
state protoimpl.MessageState `protogen:"open.v1"`
// JWT token (access token or ID token)
Token string `protobuf:"bytes,1,opt,name=token,proto3" json:"token,omitempty"`
// token type (e.g., "Bearer")
TokenType string `protobuf:"bytes,2,opt,name=tokenType,proto3" json:"tokenType,omitempty"`
// expiration time in seconds
ExpiresIn int64 `protobuf:"varint,3,opt,name=expiresIn,proto3" json:"expiresIn,omitempty"`
unknownFields protoimpl.UnknownFields
sizeCache protoimpl.SizeCache
}
func (x *WaitJWTTokenResponse) Reset() {
*x = WaitJWTTokenResponse{}
mi := &file_daemon_proto_msgTypes[74]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
func (x *WaitJWTTokenResponse) String() string {
return protoimpl.X.MessageStringOf(x)
}
func (*WaitJWTTokenResponse) ProtoMessage() {}
func (x *WaitJWTTokenResponse) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[74]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
ms.StoreMessageInfo(mi)
}
return ms
}
return mi.MessageOf(x)
}
// Deprecated: Use WaitJWTTokenResponse.ProtoReflect.Descriptor instead.
func (*WaitJWTTokenResponse) Descriptor() ([]byte, []int) {
return file_daemon_proto_rawDescGZIP(), []int{74}
}
func (x *WaitJWTTokenResponse) GetToken() string {
if x != nil {
return x.Token
}
return ""
}
func (x *WaitJWTTokenResponse) GetTokenType() string {
if x != nil {
return x.TokenType
}
return ""
}
func (x *WaitJWTTokenResponse) GetExpiresIn() int64 {
if x != nil {
return x.ExpiresIn
}
return 0
}
type PortInfo_Range struct {
state protoimpl.MessageState `protogen:"open.v1"`
Start uint32 `protobuf:"varint,1,opt,name=start,proto3" json:"start,omitempty"`
@@ -4784,7 +5064,7 @@ type PortInfo_Range struct {
func (x *PortInfo_Range) Reset() {
*x = PortInfo_Range{}
mi := &file_daemon_proto_msgTypes[72]
mi := &file_daemon_proto_msgTypes[76]
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
ms.StoreMessageInfo(mi)
}
@@ -4796,7 +5076,7 @@ func (x *PortInfo_Range) String() string {
func (*PortInfo_Range) ProtoMessage() {}
func (x *PortInfo_Range) ProtoReflect() protoreflect.Message {
mi := &file_daemon_proto_msgTypes[72]
mi := &file_daemon_proto_msgTypes[76]
if x != nil {
ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
if ms.LoadMessageInfo() == nil {
@@ -4831,7 +5111,7 @@ var File_daemon_proto protoreflect.FileDescriptor
const file_daemon_proto_rawDesc = "" +
"\n" +
"\fdaemon.proto\x12\x06daemon\x1a google/protobuf/descriptor.proto\x1a\x1fgoogle/protobuf/timestamp.proto\x1a\x1egoogle/protobuf/duration.proto\"\x0e\n" +
"\fEmptyRequest\"\x94\x11\n" +
"\fEmptyRequest\"\xd4\x11\n" +
"\fLoginRequest\x12\x1a\n" +
"\bsetupKey\x18\x01 \x01(\tR\bsetupKey\x12&\n" +
"\fpreSharedKey\x18\x02 \x01(\tB\x02\x18\x01R\fpreSharedKey\x12$\n" +
@@ -4872,7 +5152,8 @@ const file_daemon_proto_rawDesc = "" +
"\renableSSHRoot\x18! \x01(\bH\x14R\renableSSHRoot\x88\x01\x01\x12)\n" +
"\renableSSHSFTP\x18\" \x01(\bH\x15R\renableSSHSFTP\x88\x01\x01\x12G\n" +
"\x1cenableSSHLocalPortForwarding\x18# \x01(\bH\x16R\x1cenableSSHLocalPortForwarding\x88\x01\x01\x12I\n" +
"\x1denableSSHRemotePortForwarding\x18$ \x01(\bH\x17R\x1denableSSHRemotePortForwarding\x88\x01\x01B\x13\n" +
"\x1denableSSHRemotePortForwarding\x18$ \x01(\bH\x17R\x1denableSSHRemotePortForwarding\x88\x01\x01\x12+\n" +
"\x0edisableSSHAuth\x18% \x01(\bH\x18R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" +
@@ -4896,7 +5177,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_enableSSHRootB\x10\n" +
"\x0e_enableSSHSFTPB\x1f\n" +
"\x1d_enableSSHLocalPortForwardingB \n" +
"\x1e_enableSSHRemotePortForwarding\"\xb5\x01\n" +
"\x1e_enableSSHRemotePortForwardingB\x11\n" +
"\x0f_disableSSHAuth\"\xb5\x01\n" +
"\rLoginResponse\x12$\n" +
"\rneedsSSOLogin\x18\x01 \x01(\bR\rneedsSSOLogin\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\x12(\n" +
@@ -4929,7 +5211,7 @@ const file_daemon_proto_rawDesc = "" +
"\fDownResponse\"P\n" +
"\x10GetConfigRequest\x12 \n" +
"\vprofileName\x18\x01 \x01(\tR\vprofileName\x12\x1a\n" +
"\busername\x18\x02 \x01(\tR\busername\"\x8b\b\n" +
"\busername\x18\x02 \x01(\tR\busername\"\xb3\b\n" +
"\x11GetConfigResponse\x12$\n" +
"\rmanagementUrl\x18\x01 \x01(\tR\rmanagementUrl\x12\x1e\n" +
"\n" +
@@ -4958,7 +5240,8 @@ const file_daemon_proto_rawDesc = "" +
"\renableSSHRoot\x18\x15 \x01(\bR\renableSSHRoot\x12$\n" +
"\renableSSHSFTP\x18\x18 \x01(\bR\renableSSHSFTP\x12B\n" +
"\x1cenableSSHLocalPortForwarding\x18\x16 \x01(\bR\x1cenableSSHLocalPortForwarding\x12D\n" +
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\"\xfe\x05\n" +
"\x1denableSSHRemotePortForwarding\x18\x17 \x01(\bR\x1denableSSHRemotePortForwarding\x12&\n" +
"\x0edisableSSHAuth\x18\x19 \x01(\bR\x0edisableSSHAuth\"\xfe\x05\n" +
"\tPeerState\x12\x0e\n" +
"\x02IP\x18\x01 \x01(\tR\x02IP\x12\x16\n" +
"\x06pubKey\x18\x02 \x01(\tR\x06pubKey\x12\x1e\n" +
@@ -5161,7 +5444,7 @@ const file_daemon_proto_rawDesc = "" +
"\busername\x18\x02 \x01(\tH\x01R\busername\x88\x01\x01B\x0e\n" +
"\f_profileNameB\v\n" +
"\t_username\"\x17\n" +
"\x15SwitchProfileResponse\"\xcd\x0f\n" +
"\x15SwitchProfileResponse\"\x8d\x10\n" +
"\x10SetConfigRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
"\vprofileName\x18\x02 \x01(\tR\vprofileName\x12$\n" +
@@ -5198,7 +5481,8 @@ const file_daemon_proto_rawDesc = "" +
"\renableSSHRoot\x18\x1d \x01(\bH\x12R\renableSSHRoot\x88\x01\x01\x12)\n" +
"\renableSSHSFTP\x18\x1e \x01(\bH\x13R\renableSSHSFTP\x88\x01\x01\x12A\n" +
"\x19enableSSHLocalPortForward\x18\x1f \x01(\bH\x14R\x19enableSSHLocalPortForward\x88\x01\x01\x12C\n" +
"\x1aenableSSHRemotePortForward\x18 \x01(\bH\x15R\x1aenableSSHRemotePortForward\x88\x01\x01B\x13\n" +
"\x1aenableSSHRemotePortForward\x18 \x01(\bH\x15R\x1aenableSSHRemotePortForward\x88\x01\x01\x12+\n" +
"\x0edisableSSHAuth\x18! \x01(\bH\x16R\x0edisableSSHAuth\x88\x01\x01B\x13\n" +
"\x11_rosenpassEnabledB\x10\n" +
"\x0e_interfaceNameB\x10\n" +
"\x0e_wireguardPortB\x17\n" +
@@ -5220,7 +5504,8 @@ const file_daemon_proto_rawDesc = "" +
"\x0e_enableSSHRootB\x10\n" +
"\x0e_enableSSHSFTPB\x1c\n" +
"\x1a_enableSSHLocalPortForwardB\x1d\n" +
"\x1b_enableSSHRemotePortForward\"\x13\n" +
"\x1b_enableSSHRemotePortForwardB\x11\n" +
"\x0f_disableSSHAuth\"\x13\n" +
"\x11SetConfigResponse\"Q\n" +
"\x11AddProfileRequest\x12\x1a\n" +
"\busername\x18\x01 \x01(\tR\busername\x12 \n" +
@@ -5259,7 +5544,27 @@ const file_daemon_proto_rawDesc = "" +
"sshHostKey\x12\x16\n" +
"\x06peerIP\x18\x02 \x01(\tR\x06peerIP\x12\x1a\n" +
"\bpeerFQDN\x18\x03 \x01(\tR\bpeerFQDN\x12\x14\n" +
"\x05found\x18\x04 \x01(\bR\x05found*b\n" +
"\x05found\x18\x04 \x01(\bR\x05found\"\x17\n" +
"\x15RequestJWTAuthRequest\"\x9a\x02\n" +
"\x16RequestJWTAuthResponse\x12(\n" +
"\x0fverificationURI\x18\x01 \x01(\tR\x0fverificationURI\x128\n" +
"\x17verificationURIComplete\x18\x02 \x01(\tR\x17verificationURIComplete\x12\x1a\n" +
"\buserCode\x18\x03 \x01(\tR\buserCode\x12\x1e\n" +
"\n" +
"deviceCode\x18\x04 \x01(\tR\n" +
"deviceCode\x12\x1c\n" +
"\texpiresIn\x18\x05 \x01(\x03R\texpiresIn\x12 \n" +
"\vcachedToken\x18\x06 \x01(\tR\vcachedToken\x12 \n" +
"\vmaxTokenAge\x18\a \x01(\x03R\vmaxTokenAge\"Q\n" +
"\x13WaitJWTTokenRequest\x12\x1e\n" +
"\n" +
"deviceCode\x18\x01 \x01(\tR\n" +
"deviceCode\x12\x1a\n" +
"\buserCode\x18\x02 \x01(\tR\buserCode\"h\n" +
"\x14WaitJWTTokenResponse\x12\x14\n" +
"\x05token\x18\x01 \x01(\tR\x05token\x12\x1c\n" +
"\ttokenType\x18\x02 \x01(\tR\ttokenType\x12\x1c\n" +
"\texpiresIn\x18\x03 \x01(\x03R\texpiresIn*b\n" +
"\bLogLevel\x12\v\n" +
"\aUNKNOWN\x10\x00\x12\t\n" +
"\x05PANIC\x10\x01\x12\t\n" +
@@ -5268,7 +5573,7 @@ const file_daemon_proto_rawDesc = "" +
"\x04WARN\x10\x04\x12\b\n" +
"\x04INFO\x10\x05\x12\t\n" +
"\x05DEBUG\x10\x06\x12\t\n" +
"\x05TRACE\x10\a2\xeb\x10\n" +
"\x05TRACE\x10\a2\x8b\x12\n" +
"\rDaemonService\x126\n" +
"\x05Login\x12\x14.daemon.LoginRequest\x1a\x15.daemon.LoginResponse\"\x00\x12K\n" +
"\fWaitSSOLogin\x12\x1b.daemon.WaitSSOLoginRequest\x1a\x1c.daemon.WaitSSOLoginResponse\"\x00\x12-\n" +
@@ -5301,7 +5606,9 @@ const file_daemon_proto_rawDesc = "" +
"\x10GetActiveProfile\x12\x1f.daemon.GetActiveProfileRequest\x1a .daemon.GetActiveProfileResponse\"\x00\x129\n" +
"\x06Logout\x12\x15.daemon.LogoutRequest\x1a\x16.daemon.LogoutResponse\"\x00\x12H\n" +
"\vGetFeatures\x12\x1a.daemon.GetFeaturesRequest\x1a\x1b.daemon.GetFeaturesResponse\"\x00\x12Z\n" +
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00B\bZ\x06/protob\x06proto3"
"\x11GetPeerSSHHostKey\x12 .daemon.GetPeerSSHHostKeyRequest\x1a!.daemon.GetPeerSSHHostKeyResponse\"\x00\x12Q\n" +
"\x0eRequestJWTAuth\x12\x1d.daemon.RequestJWTAuthRequest\x1a\x1e.daemon.RequestJWTAuthResponse\"\x00\x12K\n" +
"\fWaitJWTToken\x12\x1b.daemon.WaitJWTTokenRequest\x1a\x1c.daemon.WaitJWTTokenResponse\"\x00B\bZ\x06/protob\x06proto3"
var (
file_daemon_proto_rawDescOnce sync.Once
@@ -5316,7 +5623,7 @@ func file_daemon_proto_rawDescGZIP() []byte {
}
var file_daemon_proto_enumTypes = make([]protoimpl.EnumInfo, 3)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 74)
var file_daemon_proto_msgTypes = make([]protoimpl.MessageInfo, 78)
var file_daemon_proto_goTypes = []any{
(LogLevel)(0), // 0: daemon.LogLevel
(SystemEvent_Severity)(0), // 1: daemon.SystemEvent.Severity
@@ -5392,18 +5699,22 @@ var file_daemon_proto_goTypes = []any{
(*GetFeaturesResponse)(nil), // 71: daemon.GetFeaturesResponse
(*GetPeerSSHHostKeyRequest)(nil), // 72: daemon.GetPeerSSHHostKeyRequest
(*GetPeerSSHHostKeyResponse)(nil), // 73: daemon.GetPeerSSHHostKeyResponse
nil, // 74: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 75: daemon.PortInfo.Range
nil, // 76: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 77: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 78: google.protobuf.Timestamp
(*RequestJWTAuthRequest)(nil), // 74: daemon.RequestJWTAuthRequest
(*RequestJWTAuthResponse)(nil), // 75: daemon.RequestJWTAuthResponse
(*WaitJWTTokenRequest)(nil), // 76: daemon.WaitJWTTokenRequest
(*WaitJWTTokenResponse)(nil), // 77: daemon.WaitJWTTokenResponse
nil, // 78: daemon.Network.ResolvedIPsEntry
(*PortInfo_Range)(nil), // 79: daemon.PortInfo.Range
nil, // 80: daemon.SystemEvent.MetadataEntry
(*durationpb.Duration)(nil), // 81: google.protobuf.Duration
(*timestamppb.Timestamp)(nil), // 82: google.protobuf.Timestamp
}
var file_daemon_proto_depIdxs = []int32{
77, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
81, // 0: daemon.LoginRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
22, // 1: daemon.StatusResponse.fullStatus:type_name -> daemon.FullStatus
78, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
78, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
77, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
82, // 2: daemon.PeerState.connStatusUpdate:type_name -> google.protobuf.Timestamp
82, // 3: daemon.PeerState.lastWireguardHandshake:type_name -> google.protobuf.Timestamp
81, // 4: daemon.PeerState.latency:type_name -> google.protobuf.Duration
19, // 5: daemon.FullStatus.managementState:type_name -> daemon.ManagementState
18, // 6: daemon.FullStatus.signalState:type_name -> daemon.SignalState
17, // 7: daemon.FullStatus.localPeerState:type_name -> daemon.LocalPeerState
@@ -5412,8 +5723,8 @@ var file_daemon_proto_depIdxs = []int32{
21, // 10: daemon.FullStatus.dns_servers:type_name -> daemon.NSGroupState
52, // 11: daemon.FullStatus.events:type_name -> daemon.SystemEvent
28, // 12: daemon.ListNetworksResponse.routes:type_name -> daemon.Network
74, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
75, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
78, // 13: daemon.Network.resolvedIPs:type_name -> daemon.Network.ResolvedIPsEntry
79, // 14: daemon.PortInfo.range:type_name -> daemon.PortInfo.Range
29, // 15: daemon.ForwardingRule.destinationPort:type_name -> daemon.PortInfo
29, // 16: daemon.ForwardingRule.translatedPort:type_name -> daemon.PortInfo
30, // 17: daemon.ForwardingRulesResponse.rules:type_name -> daemon.ForwardingRule
@@ -5424,10 +5735,10 @@ var file_daemon_proto_depIdxs = []int32{
49, // 22: daemon.TracePacketResponse.stages:type_name -> daemon.TraceStage
1, // 23: daemon.SystemEvent.severity:type_name -> daemon.SystemEvent.Severity
2, // 24: daemon.SystemEvent.category:type_name -> daemon.SystemEvent.Category
78, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
76, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
82, // 25: daemon.SystemEvent.timestamp:type_name -> google.protobuf.Timestamp
80, // 26: daemon.SystemEvent.metadata:type_name -> daemon.SystemEvent.MetadataEntry
52, // 27: daemon.GetEventsResponse.events:type_name -> daemon.SystemEvent
77, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
81, // 28: daemon.SetConfigRequest.dnsRouteInterval:type_name -> google.protobuf.Duration
65, // 29: daemon.ListProfilesResponse.profiles:type_name -> daemon.Profile
27, // 30: daemon.Network.ResolvedIPsEntry.value:type_name -> daemon.IPList
4, // 31: daemon.DaemonService.Login:input_type -> daemon.LoginRequest
@@ -5459,37 +5770,41 @@ var file_daemon_proto_depIdxs = []int32{
68, // 57: daemon.DaemonService.Logout:input_type -> daemon.LogoutRequest
70, // 58: daemon.DaemonService.GetFeatures:input_type -> daemon.GetFeaturesRequest
72, // 59: daemon.DaemonService.GetPeerSSHHostKey:input_type -> daemon.GetPeerSSHHostKeyRequest
5, // 60: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 61: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 62: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 63: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 64: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 65: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 66: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 67: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 68: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 69: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 70: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 71: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 72: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 73: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 74: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 75: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 76: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
50, // 77: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 78: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 79: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
56, // 80: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
58, // 81: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
60, // 82: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
62, // 83: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
64, // 84: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
67, // 85: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
69, // 86: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
71, // 87: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
73, // 88: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
60, // [60:89] is the sub-list for method output_type
31, // [31:60] is the sub-list for method input_type
74, // 60: daemon.DaemonService.RequestJWTAuth:input_type -> daemon.RequestJWTAuthRequest
76, // 61: daemon.DaemonService.WaitJWTToken:input_type -> daemon.WaitJWTTokenRequest
5, // 62: daemon.DaemonService.Login:output_type -> daemon.LoginResponse
7, // 63: daemon.DaemonService.WaitSSOLogin:output_type -> daemon.WaitSSOLoginResponse
9, // 64: daemon.DaemonService.Up:output_type -> daemon.UpResponse
11, // 65: daemon.DaemonService.Status:output_type -> daemon.StatusResponse
13, // 66: daemon.DaemonService.Down:output_type -> daemon.DownResponse
15, // 67: daemon.DaemonService.GetConfig:output_type -> daemon.GetConfigResponse
24, // 68: daemon.DaemonService.ListNetworks:output_type -> daemon.ListNetworksResponse
26, // 69: daemon.DaemonService.SelectNetworks:output_type -> daemon.SelectNetworksResponse
26, // 70: daemon.DaemonService.DeselectNetworks:output_type -> daemon.SelectNetworksResponse
31, // 71: daemon.DaemonService.ForwardingRules:output_type -> daemon.ForwardingRulesResponse
33, // 72: daemon.DaemonService.DebugBundle:output_type -> daemon.DebugBundleResponse
35, // 73: daemon.DaemonService.GetLogLevel:output_type -> daemon.GetLogLevelResponse
37, // 74: daemon.DaemonService.SetLogLevel:output_type -> daemon.SetLogLevelResponse
40, // 75: daemon.DaemonService.ListStates:output_type -> daemon.ListStatesResponse
42, // 76: daemon.DaemonService.CleanState:output_type -> daemon.CleanStateResponse
44, // 77: daemon.DaemonService.DeleteState:output_type -> daemon.DeleteStateResponse
46, // 78: daemon.DaemonService.SetSyncResponsePersistence:output_type -> daemon.SetSyncResponsePersistenceResponse
50, // 79: daemon.DaemonService.TracePacket:output_type -> daemon.TracePacketResponse
52, // 80: daemon.DaemonService.SubscribeEvents:output_type -> daemon.SystemEvent
54, // 81: daemon.DaemonService.GetEvents:output_type -> daemon.GetEventsResponse
56, // 82: daemon.DaemonService.SwitchProfile:output_type -> daemon.SwitchProfileResponse
58, // 83: daemon.DaemonService.SetConfig:output_type -> daemon.SetConfigResponse
60, // 84: daemon.DaemonService.AddProfile:output_type -> daemon.AddProfileResponse
62, // 85: daemon.DaemonService.RemoveProfile:output_type -> daemon.RemoveProfileResponse
64, // 86: daemon.DaemonService.ListProfiles:output_type -> daemon.ListProfilesResponse
67, // 87: daemon.DaemonService.GetActiveProfile:output_type -> daemon.GetActiveProfileResponse
69, // 88: daemon.DaemonService.Logout:output_type -> daemon.LogoutResponse
71, // 89: daemon.DaemonService.GetFeatures:output_type -> daemon.GetFeaturesResponse
73, // 90: daemon.DaemonService.GetPeerSSHHostKey:output_type -> daemon.GetPeerSSHHostKeyResponse
75, // 91: daemon.DaemonService.RequestJWTAuth:output_type -> daemon.RequestJWTAuthResponse
77, // 92: daemon.DaemonService.WaitJWTToken:output_type -> daemon.WaitJWTTokenResponse
62, // [62:93] is the sub-list for method output_type
31, // [31:62] is the sub-list for method input_type
31, // [31:31] is the sub-list for extension type_name
31, // [31:31] is the sub-list for extension extendee
0, // [0:31] is the sub-list for field type_name
@@ -5518,7 +5833,7 @@ func file_daemon_proto_init() {
GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
RawDescriptor: unsafe.Slice(unsafe.StringData(file_daemon_proto_rawDesc), len(file_daemon_proto_rawDesc)),
NumEnums: 3,
NumMessages: 74,
NumMessages: 78,
NumExtensions: 0,
NumServices: 1,
},

View File

@@ -87,6 +87,12 @@ service DaemonService {
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
rpc GetPeerSSHHostKey(GetPeerSSHHostKeyRequest) returns (GetPeerSSHHostKeyResponse) {}
// RequestJWTAuth initiates JWT authentication flow for SSH
rpc RequestJWTAuth(RequestJWTAuthRequest) returns (RequestJWTAuthResponse) {}
// WaitJWTToken waits for JWT authentication completion
rpc WaitJWTToken(WaitJWTTokenRequest) returns (WaitJWTTokenResponse) {}
}
@@ -166,6 +172,7 @@ message LoginRequest {
optional bool enableSSHSFTP = 34;
optional bool enableSSHLocalPortForwarding = 35;
optional bool enableSSHRemotePortForwarding = 36;
optional bool disableSSHAuth = 37;
}
message LoginResponse {
@@ -268,6 +275,8 @@ message GetConfigResponse {
bool enableSSHLocalPortForwarding = 22;
bool enableSSHRemotePortForwarding = 23;
bool disableSSHAuth = 25;
}
// PeerState contains the latest state of a peer
@@ -612,6 +621,7 @@ message SetConfigRequest {
optional bool enableSSHSFTP = 30;
optional bool enableSSHLocalPortForward = 31;
optional bool enableSSHRemotePortForward = 32;
optional bool disableSSHAuth = 33;
}
message SetConfigResponse{}
@@ -681,3 +691,43 @@ message GetPeerSSHHostKeyResponse {
// indicates if the SSH host key was found
bool found = 4;
}
// RequestJWTAuthRequest for initiating JWT authentication flow
message RequestJWTAuthRequest {
}
// RequestJWTAuthResponse contains authentication flow information
message RequestJWTAuthResponse {
// verification URI for user authentication
string verificationURI = 1;
// complete verification URI (with embedded user code)
string verificationURIComplete = 2;
// user code to enter on verification URI
string userCode = 3;
// device code for polling
string deviceCode = 4;
// expiration time in seconds
int64 expiresIn = 5;
// if a cached token is available, it will be returned here
string cachedToken = 6;
// maximum age of JWT tokens in seconds (from management server)
int64 maxTokenAge = 7;
}
// WaitJWTTokenRequest for waiting for authentication completion
message WaitJWTTokenRequest {
// device code from RequestJWTAuthResponse
string deviceCode = 1;
// user code for verification
string userCode = 2;
}
// WaitJWTTokenResponse contains the JWT token after authentication
message WaitJWTTokenResponse {
// JWT token (access token or ID token)
string token = 1;
// token type (e.g., "Bearer")
string tokenType = 2;
// expiration time in seconds
int64 expiresIn = 3;
}

View File

@@ -66,6 +66,10 @@ type DaemonServiceClient interface {
GetFeatures(ctx context.Context, in *GetFeaturesRequest, opts ...grpc.CallOption) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(ctx context.Context, in *GetPeerSSHHostKeyRequest, opts ...grpc.CallOption) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error)
}
type daemonServiceClient struct {
@@ -360,6 +364,24 @@ func (c *daemonServiceClient) GetPeerSSHHostKey(ctx context.Context, in *GetPeer
return out, nil
}
func (c *daemonServiceClient) RequestJWTAuth(ctx context.Context, in *RequestJWTAuthRequest, opts ...grpc.CallOption) (*RequestJWTAuthResponse, error) {
out := new(RequestJWTAuthResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/RequestJWTAuth", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
func (c *daemonServiceClient) WaitJWTToken(ctx context.Context, in *WaitJWTTokenRequest, opts ...grpc.CallOption) (*WaitJWTTokenResponse, error) {
out := new(WaitJWTTokenResponse)
err := c.cc.Invoke(ctx, "/daemon.DaemonService/WaitJWTToken", in, out, opts...)
if err != nil {
return nil, err
}
return out, nil
}
// DaemonServiceServer is the server API for DaemonService service.
// All implementations must embed UnimplementedDaemonServiceServer
// for forward compatibility
@@ -412,6 +434,10 @@ type DaemonServiceServer interface {
GetFeatures(context.Context, *GetFeaturesRequest) (*GetFeaturesResponse, error)
// GetPeerSSHHostKey retrieves SSH host key for a specific peer
GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error)
// RequestJWTAuth initiates JWT authentication flow for SSH
RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error)
// WaitJWTToken waits for JWT authentication completion
WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error)
mustEmbedUnimplementedDaemonServiceServer()
}
@@ -506,6 +532,12 @@ func (UnimplementedDaemonServiceServer) GetFeatures(context.Context, *GetFeature
func (UnimplementedDaemonServiceServer) GetPeerSSHHostKey(context.Context, *GetPeerSSHHostKeyRequest) (*GetPeerSSHHostKeyResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method GetPeerSSHHostKey not implemented")
}
func (UnimplementedDaemonServiceServer) RequestJWTAuth(context.Context, *RequestJWTAuthRequest) (*RequestJWTAuthResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method RequestJWTAuth not implemented")
}
func (UnimplementedDaemonServiceServer) WaitJWTToken(context.Context, *WaitJWTTokenRequest) (*WaitJWTTokenResponse, error) {
return nil, status.Errorf(codes.Unimplemented, "method WaitJWTToken not implemented")
}
func (UnimplementedDaemonServiceServer) mustEmbedUnimplementedDaemonServiceServer() {}
// UnsafeDaemonServiceServer may be embedded to opt out of forward compatibility for this service.
@@ -1044,6 +1076,42 @@ func _DaemonService_GetPeerSSHHostKey_Handler(srv interface{}, ctx context.Conte
return interceptor(ctx, in, info, handler)
}
func _DaemonService_RequestJWTAuth_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(RequestJWTAuthRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/RequestJWTAuth",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).RequestJWTAuth(ctx, req.(*RequestJWTAuthRequest))
}
return interceptor(ctx, in, info, handler)
}
func _DaemonService_WaitJWTToken_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) {
in := new(WaitJWTTokenRequest)
if err := dec(in); err != nil {
return nil, err
}
if interceptor == nil {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, in)
}
info := &grpc.UnaryServerInfo{
Server: srv,
FullMethod: "/daemon.DaemonService/WaitJWTToken",
}
handler := func(ctx context.Context, req interface{}) (interface{}, error) {
return srv.(DaemonServiceServer).WaitJWTToken(ctx, req.(*WaitJWTTokenRequest))
}
return interceptor(ctx, in, info, handler)
}
// DaemonService_ServiceDesc is the grpc.ServiceDesc for DaemonService service.
// It's only intended for direct use with grpc.RegisterService,
// and not to be introspected or modified (even as a copy)
@@ -1163,6 +1231,14 @@ var DaemonService_ServiceDesc = grpc.ServiceDesc{
MethodName: "GetPeerSSHHostKey",
Handler: _DaemonService_GetPeerSSHHostKey_Handler,
},
{
MethodName: "RequestJWTAuth",
Handler: _DaemonService_RequestJWTAuth_Handler,
},
{
MethodName: "WaitJWTToken",
Handler: _DaemonService_WaitJWTToken_Handler,
},
},
Streams: []grpc.StreamDesc{
{

View File

@@ -0,0 +1,73 @@
package server
import (
"sync"
"time"
"github.com/awnumar/memguard"
log "github.com/sirupsen/logrus"
)
type jwtCache struct {
mu sync.RWMutex
enclave *memguard.Enclave
expiresAt time.Time
timer *time.Timer
maxTokenSize int
}
func newJWTCache() *jwtCache {
return &jwtCache{
maxTokenSize: 8192,
}
}
func (c *jwtCache) store(token string, maxAge time.Duration) {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanup()
if c.timer != nil {
c.timer.Stop()
}
tokenBytes := []byte(token)
c.enclave = memguard.NewEnclave(tokenBytes)
c.expiresAt = time.Now().Add(maxAge)
c.timer = time.AfterFunc(maxAge, func() {
c.mu.Lock()
defer c.mu.Unlock()
c.cleanup()
c.timer = nil
log.Debugf("JWT token cache expired after %v, securely wiped from memory", maxAge)
})
}
func (c *jwtCache) get() (string, bool) {
c.mu.RLock()
defer c.mu.RUnlock()
if c.enclave == nil || time.Now().After(c.expiresAt) {
return "", false
}
buffer, err := c.enclave.Open()
if err != nil {
log.Debugf("Failed to open JWT token enclave: %v", err)
return "", false
}
defer buffer.Destroy()
token := string(buffer.Bytes())
return token, true
}
// cleanup destroys the secure enclave, must be called with lock held
func (c *jwtCache) cleanup() {
if c.enclave != nil {
c.enclave = nil
}
}

View File

@@ -11,8 +11,8 @@ import (
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
type selectRoute struct {

View File

@@ -46,6 +46,9 @@ const (
defaultMaxRetryTime = 14 * 24 * time.Hour
defaultRetryMultiplier = 1.7
// JWT token cache TTL for the client daemon
defaultJWTCacheTTL = 5 * time.Minute
errRestoreResidualState = "failed to restore residual state: %v"
errProfilesDisabled = "profiles are disabled, you cannot use this feature without profiles enabled"
errUpdateSettingsDisabled = "update settings are disabled, you cannot use this feature without update settings enabled"
@@ -81,6 +84,8 @@ type Server struct {
profileManager *profilemanager.ServiceManager
profilesDisabled bool
updateSettingsDisabled bool
jwtCache *jwtCache
}
type oauthAuthFlow struct {
@@ -100,6 +105,7 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
profileManager: profilemanager.NewServiceManager(configFile),
profilesDisabled: profilesDisabled,
updateSettingsDisabled: updateSettingsDisabled,
jwtCache: newJWTCache(),
}
}
@@ -370,6 +376,9 @@ func (s *Server) SetConfig(callerCtx context.Context, msg *proto.SetConfigReques
config.EnableSSHSFTP = msg.EnableSSHSFTP
config.EnableSSHLocalPortForwarding = msg.EnableSSHLocalPortForward
config.EnableSSHRemotePortForwarding = msg.EnableSSHRemotePortForward
if msg.DisableSSHAuth != nil {
config.DisableSSHAuth = msg.DisableSSHAuth
}
if msg.Mtu != nil {
mtu := uint16(*msg.Mtu)
@@ -486,7 +495,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
return nil, err
}
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(context.TODO()) {
if s.oauthAuthFlow.flow != nil && s.oauthAuthFlow.flow.GetClientID(ctx) == oAuthFlow.GetClientID(ctx) {
if s.oauthAuthFlow.expiresAt.After(time.Now().Add(90 * time.Second)) {
log.Debugf("using previous oauth flow info")
return &proto.LoginResponse{
@@ -503,7 +512,7 @@ func (s *Server) Login(callerCtx context.Context, msg *proto.LoginRequest) (*pro
}
}
authInfo, err := oAuthFlow.RequestAuthInfo(context.TODO())
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
log.Errorf("getting a request OAuth flow failed: %v", err)
return nil, err
@@ -1077,28 +1086,41 @@ func (s *Server) GetPeerSSHHostKey(
}
s.mutex.Lock()
defer s.mutex.Unlock()
connectClient := s.connectClient
statusRecorder := s.statusRecorder
s.mutex.Unlock()
response := &proto.GetPeerSSHHostKeyResponse{
Found: false,
if connectClient == nil {
return nil, errors.New("client not initialized")
}
if s.statusRecorder == nil {
engine := connectClient.Engine()
if engine == nil {
return nil, errors.New("engine not started")
}
peerAddress := req.GetPeerAddress()
hostKey, found := engine.GetPeerSSHKey(peerAddress)
response := &proto.GetPeerSSHHostKeyResponse{
Found: found,
}
if !found {
return response, nil
}
fullStatus := s.statusRecorder.GetFullStatus()
peerAddress := req.GetPeerAddress()
response.SshHostKey = hostKey
// Search for peer by IP or FQDN
if statusRecorder == nil {
return response, nil
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
response.SshHostKey = peerState.SSHHostKey
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
response.Found = true
}
response.PeerIP = peerState.IP
response.PeerFQDN = peerState.FQDN
break
}
}
@@ -1106,6 +1128,137 @@ func (s *Server) GetPeerSSHHostKey(
return response, nil
}
// getJWTCacheTTL returns the JWT cache TTL from environment variable or default
// NB_SSH_JWT_CACHE_TTL=0 disables caching
// NB_SSH_JWT_CACHE_TTL=<seconds> sets custom cache TTL
func getJWTCacheTTL() time.Duration {
envValue := os.Getenv("NB_SSH_JWT_CACHE_TTL")
if envValue == "" {
return defaultJWTCacheTTL
}
seconds, err := strconv.Atoi(envValue)
if err != nil {
log.Warnf("invalid NB_SSH_JWT_CACHE_TTL value %s, using default: %v", envValue, defaultJWTCacheTTL)
return defaultJWTCacheTTL
}
if seconds == 0 {
log.Info("SSH JWT cache disabled via NB_SSH_JWT_CACHE_TTL=0")
return 0
}
ttl := time.Duration(seconds) * time.Second
log.Infof("SSH JWT cache TTL set to %v via NB_SSH_JWT_CACHE_TTL", ttl)
return ttl
}
// RequestJWTAuth initiates JWT authentication flow for SSH
func (s *Server) RequestJWTAuth(
ctx context.Context,
_ *proto.RequestJWTAuthRequest,
) (*proto.RequestJWTAuthResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
config := s.config
s.mutex.Unlock()
if config == nil {
return nil, gstatus.Errorf(codes.FailedPrecondition, "client is not configured")
}
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
if cachedToken, found := s.jwtCache.get(); found {
log.Debugf("JWT token found in cache, returning cached token for SSH authentication")
return &proto.RequestJWTAuthResponse{
CachedToken: cachedToken,
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
}
isDesktop := isUnixRunningDesktop()
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isDesktop)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to create OAuth flow: %v", err)
}
authInfo, err := oAuthFlow.RequestAuthInfo(ctx)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to request auth info: %v", err)
}
s.mutex.Lock()
s.oauthAuthFlow.flow = oAuthFlow
s.oauthAuthFlow.info = authInfo
s.oauthAuthFlow.expiresAt = time.Now().Add(time.Duration(authInfo.ExpiresIn) * time.Second)
s.mutex.Unlock()
return &proto.RequestJWTAuthResponse{
VerificationURI: authInfo.VerificationURI,
VerificationURIComplete: authInfo.VerificationURIComplete,
UserCode: authInfo.UserCode,
DeviceCode: authInfo.DeviceCode,
ExpiresIn: int64(authInfo.ExpiresIn),
MaxTokenAge: int64(jwtCacheTTL.Seconds()),
}, nil
}
// WaitJWTToken waits for JWT authentication completion
func (s *Server) WaitJWTToken(
ctx context.Context,
req *proto.WaitJWTTokenRequest,
) (*proto.WaitJWTTokenResponse, error) {
if ctx.Err() != nil {
return nil, ctx.Err()
}
s.mutex.Lock()
oAuthFlow := s.oauthAuthFlow.flow
authInfo := s.oauthAuthFlow.info
s.mutex.Unlock()
if oAuthFlow == nil || authInfo.DeviceCode != req.DeviceCode {
return nil, gstatus.Errorf(codes.InvalidArgument, "invalid device code or no active auth flow")
}
tokenInfo, err := oAuthFlow.WaitToken(ctx, authInfo)
if err != nil {
return nil, gstatus.Errorf(codes.Internal, "failed to get token: %v", err)
}
token := tokenInfo.GetTokenToUse()
jwtCacheTTL := getJWTCacheTTL()
if jwtCacheTTL > 0 {
s.jwtCache.store(token, jwtCacheTTL)
log.Debugf("JWT token cached for SSH authentication, TTL: %v", jwtCacheTTL)
} else {
log.Debug("JWT caching disabled, not storing token")
}
s.mutex.Lock()
s.oauthAuthFlow = oauthAuthFlow{}
s.mutex.Unlock()
return &proto.WaitJWTTokenResponse{
Token: tokenInfo.GetTokenToUse(),
TokenType: tokenInfo.TokenType,
ExpiresIn: int64(tokenInfo.ExpiresIn),
}, nil
}
func isUnixRunningDesktop() bool {
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
return false
}
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
}
func (s *Server) runProbes() {
if s.connectClient == nil {
return
@@ -1191,13 +1344,18 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
enableSSHRemotePortForwarding = *s.config.EnableSSHRemotePortForwarding
}
disableSSHAuth := false
if s.config.DisableSSHAuth != nil {
disableSSHAuth = *s.config.DisableSSHAuth
}
return &proto.GetConfigResponse{
ManagementUrl: managementURL.String(),
PreSharedKey: preSharedKey,
AdminURL: adminURL.String(),
InterfaceName: cfg.WgIface,
WireguardPort: int64(cfg.WgPort),
Mtu: int64(cfg.MTU),
Mtu: int64(cfg.MTU),
DisableAutoConnect: cfg.DisableAutoConnect,
ServerSSHAllowed: *cfg.ServerSSHAllowed,
RosenpassEnabled: cfg.RosenpassEnabled,
@@ -1214,6 +1372,7 @@ func (s *Server) GetConfig(ctx context.Context, req *proto.GetConfigRequest) (*p
EnableSSHSFTP: enableSSHSFTP,
EnableSSHLocalPortForwarding: enableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: enableSSHRemotePortForwarding,
DisableSSHAuth: disableSSHAuth,
}, nil
}

View File

@@ -6,9 +6,11 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&dns.ShutdownState{})
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
}

View File

@@ -8,6 +8,7 @@ import (
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
"github.com/netbirdio/netbird/client/ssh/config"
)
func registerStates(mgr *statemanager.Manager) {
@@ -15,4 +16,5 @@ func registerStates(mgr *statemanager.Manager) {
mgr.RegisterState(&systemops.ShutdownState{})
mgr.RegisterState(&nftables.ShutdownState{})
mgr.RegisterState(&iptables.ShutdownState{})
mgr.RegisterState(&config.ShutdownState{})
}

View File

@@ -21,6 +21,8 @@ import (
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
)
const (
@@ -40,12 +42,10 @@ type Client struct {
windowsStdinMode uint32 // nolint:unused
}
// Close terminates the SSH connection
func (c *Client) Close() error {
return c.client.Close()
}
// OpenTerminal opens an interactive terminal session
func (c *Client) OpenTerminal(ctx context.Context) error {
session, err := c.client.NewSession()
if err != nil {
@@ -259,43 +259,29 @@ func (c *Client) createSession(ctx context.Context) (*ssh.Session, func(), error
return session, cleanup, nil
}
// Dial connects to the given ssh server with proper host key verification
func Dial(ctx context.Context, addr, user string) (*Client, error) {
hostKeyCallback, err := createHostKeyCallback(addr)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
// getDefaultDaemonAddr returns the daemon address from environment or default for the OS
func getDefaultDaemonAddr() string {
if addr := os.Getenv("NB_DAEMON_ADDR"); addr != "" {
return addr
}
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: hostKeyCallback,
if runtime.GOOS == "windows" {
return DefaultDaemonAddrWindows
}
return dial(ctx, "tcp", addr, config)
}
// DialInsecure connects to the given ssh server without host key verification (for testing only)
func DialInsecure(ctx context.Context, addr, user string) (*Client, error) {
config := &ssh.ClientConfig{
User: user,
Timeout: 30 * time.Second,
HostKeyCallback: ssh.InsecureIgnoreHostKey(), // #nosec G106 - Only used for tests
}
return dial(ctx, "tcp", addr, config)
return DefaultDaemonAddr
}
// DialOptions contains options for SSH connections
type DialOptions struct {
KnownHostsFile string
IdentityFile string
DaemonAddr string
KnownHostsFile string
IdentityFile string
DaemonAddr string
SkipCachedToken bool
InsecureSkipVerify bool
}
// DialWithOptions connects to the given ssh server with specified options
func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
hostKeyCallback, err := createHostKeyCallbackWithOptions(addr, opts)
// Dial connects to the given ssh server with specified options
func Dial(ctx context.Context, addr, user string, opts DialOptions) (*Client, error) {
hostKeyCallback, err := createHostKeyCallback(opts)
if err != nil {
return nil, fmt.Errorf("create host key callback: %w", err)
}
@@ -306,7 +292,6 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
HostKeyCallback: hostKeyCallback,
}
// Add SSH key authentication if identity file is specified
if opts.IdentityFile != "" {
authMethod, err := createSSHKeyAuth(opts.IdentityFile)
if err != nil {
@@ -315,11 +300,16 @@ func DialWithOptions(ctx context.Context, addr, user string, opts DialOptions) (
config.Auth = append(config.Auth, authMethod)
}
return dial(ctx, "tcp", addr, config)
daemonAddr := opts.DaemonAddr
if daemonAddr == "" {
daemonAddr = getDefaultDaemonAddr()
}
return dialWithJWT(ctx, "tcp", addr, config, daemonAddr, opts.SkipCachedToken)
}
// dial establishes an SSH connection
func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
// dialSSH establishes an SSH connection without JWT authentication
func dialSSH(ctx context.Context, network, addr string, config *ssh.ClientConfig) (*Client, error) {
dialer := &net.Dialer{}
conn, err := dialer.DialContext(ctx, network, addr)
if err != nil {
@@ -340,143 +330,84 @@ func dial(ctx context.Context, network, addr string, config *ssh.ClientConfig) (
}, nil
}
// createHostKeyCallback creates a host key verification callback that checks daemon first, then known_hosts files
func createHostKeyCallback(addr string) (ssh.HostKeyCallback, error) {
daemonAddr := os.Getenv("NB_DAEMON_ADDR")
if daemonAddr == "" {
if runtime.GOOS == "windows" {
daemonAddr = DefaultDaemonAddrWindows
} else {
daemonAddr = DefaultDaemonAddr
}
// dialWithJWT establishes an SSH connection with optional JWT authentication based on server detection
func dialWithJWT(ctx context.Context, network, addr string, config *ssh.ClientConfig, daemonAddr string, skipCache bool) (*Client, error) {
host, portStr, err := net.SplitHostPort(addr)
if err != nil {
return nil, fmt.Errorf("parse address %s: %w", addr, err)
}
return createHostKeyCallbackWithDaemonAddr(addr, daemonAddr)
port, err := strconv.Atoi(portStr)
if err != nil {
return nil, fmt.Errorf("parse port %s: %w", portStr, err)
}
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
return nil, fmt.Errorf("SSH server detection failed: %w", err)
}
if !serverType.RequiresJWT() {
return dialSSH(ctx, network, addr, config)
}
jwtCtx, cancel := context.WithTimeout(ctx, config.Timeout)
defer cancel()
jwtToken, err := requestJWTToken(jwtCtx, daemonAddr, skipCache)
if err != nil {
return nil, fmt.Errorf("request JWT token: %w", err)
}
configWithJWT := nbssh.AddJWTAuth(config, jwtToken)
return dialSSH(ctx, network, addr, configWithJWT)
}
// createHostKeyCallbackWithDaemonAddr creates a host key verification callback with specified daemon address
func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCallback, error) {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// First try to get host key from NetBird daemon
if err := verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr); err == nil {
return nil
}
// requestJWTToken requests a JWT token from the NetBird daemon
func requestJWTToken(ctx context.Context, daemonAddr string, skipCache bool) (string, error) {
conn, err := connectToDaemon(daemonAddr)
if err != nil {
return "", fmt.Errorf("connect to daemon: %w", err)
}
defer conn.Close()
// Fallback to known_hosts files
knownHostsFiles := getKnownHostsFiles()
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
// Try each known_hosts callback
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
}, nil
client := proto.NewDaemonServiceClient(conn)
return nbssh.RequestJWTToken(ctx, client, os.Stdout, os.Stderr, !skipCache)
}
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
client, err := connectToDaemon(daemonAddr)
conn, err := connectToDaemon(daemonAddr)
if err != nil {
return err
}
defer func() {
if err := client.Close(); err != nil {
if err := conn.Close(); err != nil {
log.Debugf("daemon connection close error: %v", err)
}
}()
addresses := buildAddressList(hostname, remote)
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
return verifyKeyWithDaemon(client, addresses, key)
client := proto.NewDaemonServiceClient(conn)
verifier := nbssh.NewDaemonHostKeyVerifier(client)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
addr := strings.TrimPrefix(daemonAddr, "tcp://")
conn, err := grpc.DialContext(
ctx,
strings.TrimPrefix(daemonAddr, "tcp://"),
conn, err := grpc.NewClient(
addr,
grpc.WithTransportCredentials(insecure.NewCredentials()),
grpc.WithBlock(),
)
if err != nil {
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err)
log.Debugf("failed to create gRPC client for NetBird daemon at %s: %v", daemonAddr, err)
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
}
return conn, nil
}
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}
func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error {
client := proto.NewDaemonServiceClient(conn)
for _, addr := range addresses {
if err := checkAddressKey(client, addr, key); err == nil {
return nil
}
}
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
}
func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: addr,
})
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
if err != nil {
log.Debugf("daemon query error for %s: %v", addr, err)
return err
}
if !response.GetFound() {
log.Debugf("SSH host key not found in daemon for address: %s", addr)
return fmt.Errorf("key not found")
}
return compareKeys(response.GetSshHostKey(), key, addr)
}
func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error {
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
return err
}
if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) {
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
return nil
}
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type())
return fmt.Errorf("key mismatch")
}
// getKnownHostsFiles returns paths to known_hosts files in order of preference
func getKnownHostsFiles() []string {
var files []string
@@ -503,8 +434,12 @@ func getKnownHostsFiles() []string {
return files
}
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) {
// createHostKeyCallback creates a host key verification callback
func createHostKeyCallback(opts DialOptions) (ssh.HostKeyCallback, error) {
if opts.InsecureSkipVerify {
return ssh.InsecureIgnoreHostKey(), nil // #nosec G106 - User explicitly requested insecure mode
}
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
return nil

View File

@@ -7,29 +7,36 @@ import (
"io"
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
"time"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/ssh"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
cleanupTestUsers()
testutil.CleanupTestUsers()
os.Exit(code)
}
@@ -39,19 +46,14 @@ func TestSSHClient_DialWithKey(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create and start server
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
@@ -62,8 +64,10 @@ func TestSSHClient_DialWithKey(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client.Close()
@@ -75,7 +79,7 @@ func TestSSHClient_DialWithKey(t *testing.T) {
}
func TestSSHClient_CommandExecution(t *testing.T) {
if runtime.GOOS == "windows" && isCI() {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows command execution tests in CI due to S4U authentication issues")
}
@@ -129,20 +133,16 @@ func TestSSHClient_ConnectionHandling(t *testing.T) {
}()
// Generate client key for multiple connections
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("multi-peer", string(clientPubKey))
require.NoError(t, err)
const numClients = 3
clients := make([]*Client, numClients)
for i := 0; i < numClients; i++ {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i))
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, fmt.Sprintf("%s-%d", currentUser, i), DialOptions{
InsecureSkipVerify: true,
})
cancel()
require.NoError(t, err, "Client %d should connect successfully", i)
clients[i] = client
@@ -161,19 +161,14 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
require.NoError(t, err)
}()
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("cancel-peer", string(clientPubKey))
require.NoError(t, err)
t.Run("connection with short timeout", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 1*time.Millisecond)
defer cancel()
currentUser := getCurrentUsername()
_, err = DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
_, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
if err != nil {
// Check for actual timeout-related errors rather than string matching
assert.True(t,
@@ -187,8 +182,10 @@ func TestSSHClient_ContextCancellation(t *testing.T) {
t.Run("command execution cancellation", func(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
@@ -214,7 +211,11 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
serverAddr := sshserver.StartTestServer(t, server)
@@ -226,10 +227,12 @@ func TestSSHClient_NoAuthMode(t *testing.T) {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
currentUser := testutil.GetTestUsername(t)
t.Run("any key succeeds in no-auth mode", func(t *testing.T) {
client, err := DialInsecure(ctx, serverAddr, currentUser)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
assert.NoError(t, err)
if client != nil {
require.NoError(t, client.Close(), "Client should close without error")
@@ -282,24 +285,22 @@ func setupTestSSHServerAndClient(t *testing.T) (*sshserver.Server, string, *Clie
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
currentUser := getCurrentUsername()
client, err := DialInsecure(ctx, serverAddr, currentUser)
currentUser := testutil.GetTestUsername(t)
client, err := Dial(ctx, serverAddr, currentUser, DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
return server, serverAddr, client
@@ -361,18 +362,14 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := sshserver.New(hostKey)
serverConfig := &sshserver.Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := sshserver.New(serverConfig)
server.SetAllowLocalPortForwarding(true)
server.SetAllowRootLogin(true) // Allow root/admin login for tests
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverAddr := sshserver.StartTestServer(t, server)
defer func() {
err := server.Stop()
@@ -387,11 +384,13 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
require.NoError(t, err)
// Skip if running as system account that can't do port forwarding
if isSystemAccount(realUser) {
if testutil.IsSystemAccount(realUser) {
t.Skipf("Skipping port forwarding test - running as system account: %s", realUser)
}
client, err := DialInsecure(ctx, serverAddr, realUser)
client, err := Dial(ctx, serverAddr, realUser, DialOptions{
InsecureSkipVerify: true, // Skip host key verification for test
})
require.NoError(t, err)
defer func() {
if err := client.Close(); err != nil {
@@ -478,180 +477,6 @@ func TestSSHClient_PortForwardingDataTransfer(t *testing.T) {
assert.Equal(t, expectedResponse, string(response))
}
// getCurrentUsername returns the current username for SSH connections
func getCurrentUsername() string {
if runtime.GOOS == "windows" {
if currentUser, err := user.Current(); err == nil {
// Check if this is a system account that can't authenticate
if isSystemAccount(currentUser.Username) {
// In CI environments, create a test user; otherwise try Administrator
if isCI() {
if testUser := getOrCreateTestUser(); testUser != "" {
return testUser
}
} else {
// Try Administrator first for local development
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := getOrCreateTestUser(); testUser != "" {
return testUser
}
}
}
// On Windows, return the full domain\username for proper authentication
return currentUser.Username
}
}
if username := os.Getenv("USER"); username != "" {
return username
}
if currentUser, err := user.Current(); err == nil {
return currentUser.Username
}
return "test-user"
}
// isCI checks if we're running in GitHub Actions CI
func isCI() bool {
// Check standard CI environment variables
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// getOrCreateTestUser creates a test user on Windows if needed
func getOrCreateTestUser() string {
testUsername := "netbird-test-user"
// Check if user already exists
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
// Try to create the user using PowerShell
if createWindowsTestUser(testUsername) {
// Register cleanup for the test user
registerTestUserCleanup(testUsername)
return testUsername
}
return ""
}
var createdTestUsers = make(map[string]bool)
var testUsersToCleanup []string
// registerTestUserCleanup registers a test user for cleanup
func registerTestUserCleanup(username string) {
if !createdTestUsers[username] {
createdTestUsers[username] = true
testUsersToCleanup = append(testUsersToCleanup, username)
}
}
// cleanupTestUsers removes all created test users
func cleanupTestUsers() {
for _, username := range testUsersToCleanup {
removeWindowsTestUser(username)
}
testUsersToCleanup = nil
createdTestUsers = make(map[string]bool)
}
// removeWindowsTestUser removes a local user on Windows using PowerShell
func removeWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
// PowerShell command to remove a local user
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// createWindowsTestUser creates a local user on Windows using PowerShell
func createWindowsTestUser(username string) bool {
if runtime.GOOS != "windows" {
return false
}
// PowerShell command to create a local user
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
log.Printf("Test user creation result: %s", string(output))
return true
}
// isSystemAccount checks if the user is a system account that can't authenticate
func isSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
// getRealCurrentUser returns the actual current user (not test user) for features like port forwarding
func getRealCurrentUser() (string, error) {
if runtime.GOOS == "windows" {

167
client/ssh/common.go Normal file
View File

@@ -0,0 +1,167 @@
package ssh
import (
"bytes"
"context"
"errors"
"fmt"
"io"
"net"
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/proto"
)
const (
NetBirdSSHConfigFile = "99-netbird.conf"
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
WindowsSSHConfigDir = "ssh/ssh_config.d"
)
var (
// ErrPeerNotFound indicates the peer was not found in the network
ErrPeerNotFound = errors.New("peer not found in network")
// ErrNoStoredKey indicates the peer has no stored SSH host key
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
)
// HostKeyVerifier provides SSH host key verification
type HostKeyVerifier interface {
VerifySSHHostKey(peerAddress string, key []byte) error
}
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
type DaemonHostKeyVerifier struct {
client proto.DaemonServiceClient
}
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
return &DaemonHostKeyVerifier{
client: client,
}
}
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: peerAddress,
})
if err != nil {
return err
}
if !response.GetFound() {
return ErrPeerNotFound
}
storedKeyData := response.GetSshHostKey()
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
}
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool) (string, error) {
authResponse, err := client.RequestJWTAuth(ctx, &proto.RequestJWTAuthRequest{})
if err != nil {
return "", fmt.Errorf("request JWT auth: %w", err)
}
if useCache && authResponse.CachedToken != "" {
log.Debug("Using cached authentication token")
return authResponse.CachedToken, nil
}
if stderr != nil {
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
if authResponse.UserCode != "" {
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
}
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
}
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
DeviceCode: authResponse.DeviceCode,
UserCode: authResponse.UserCode,
})
if err != nil {
return "", fmt.Errorf("wait for JWT token: %w", err)
}
if stdout != nil {
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
}
return tokenResponse.Token, nil
}
// VerifyHostKey verifies an SSH host key against stored peer key data.
// Returns nil only if the presented key matches the stored key.
// Returns ErrNoStoredKey if storedKeyData is empty.
// Returns an error if the keys don't match or if parsing fails.
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
if len(storedKeyData) == 0 {
return ErrNoStoredKey
}
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
}
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
}
return nil
}
// AddJWTAuth prepends JWT password authentication to existing auth methods.
// This ensures JWT auth is tried first while preserving any existing auth methods.
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
configWithJWT := *config
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
return &configWithJWT
}
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
// It tries multiple addresses (hostname, IP) for the peer before failing.
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
addresses := buildAddressList(hostname, remote)
presentedKey := key.Marshal()
for _, addr := range addresses {
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
if errors.Is(err, ErrPeerNotFound) {
// Try other addresses for this peer
continue
}
return err
}
// Verified
return nil
}
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
}
}
// buildAddressList creates a list of addresses to check for host key verification.
// It includes the original hostname and extracts the host part from the remote address if different.
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}

View File

@@ -1,7 +1,6 @@
package config
import (
"bufio"
"context"
"fmt"
"os"
@@ -12,50 +11,41 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
const (
// EnvDisableSSHConfig is the environment variable to disable SSH config management
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
// EnvForceSSHConfig is the environment variable to force SSH config generation even with many peers
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
// MaxPeersForSSHConfig is the default maximum number of peers before SSH config generation is disabled
MaxPeersForSSHConfig = 200
// fileWriteTimeout is the timeout for file write operations
fileWriteTimeout = 2 * time.Second
)
// isSSHConfigDisabled checks if SSH config management is disabled via environment variable
func isSSHConfigDisabled() bool {
value := os.Getenv(EnvDisableSSHConfig)
if value == "" {
return false
}
// Parse as boolean, default to true if non-empty but invalid
disabled, err := strconv.ParseBool(value)
if err != nil {
// If not a valid boolean, treat any non-empty value as true
return true
}
return disabled
}
// isSSHConfigForced checks if SSH config generation is forced via environment variable
func isSSHConfigForced() bool {
value := os.Getenv(EnvForceSSHConfig)
if value == "" {
return false
}
// Parse as boolean, default to true if non-empty but invalid
forced, err := strconv.ParseBool(value)
if err != nil {
// If not a valid boolean, treat any non-empty value as true
return true
}
return forced
@@ -92,85 +82,55 @@ func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error
}
}
// writeFileOperationWithTimeout performs a file operation with timeout
func writeFileOperationWithTimeout(filename string, operation func() error) error {
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
defer cancel()
done := make(chan error, 1)
go func() {
done <- operation()
}()
select {
case err := <-done:
return err
case <-ctx.Done():
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
}
}
// Manager handles SSH client configuration for NetBird peers
type Manager struct {
sshConfigDir string
sshConfigFile string
knownHostsDir string
knownHostsFile string
userKnownHosts string
sshConfigDir string
sshConfigFile string
}
// PeerHostKey represents a peer's SSH host key information
type PeerHostKey struct {
// PeerSSHInfo represents a peer's SSH configuration information
type PeerSSHInfo struct {
Hostname string
IP string
FQDN string
HostKey ssh.PublicKey
}
// NewManager creates a new SSH config manager
func NewManager() *Manager {
sshConfigDir, knownHostsDir := getSystemSSHPaths()
// New creates a new SSH config manager
func New() *Manager {
sshConfigDir := getSystemSSHConfigDir()
return &Manager{
sshConfigDir: sshConfigDir,
sshConfigFile: "99-netbird.conf",
knownHostsDir: knownHostsDir,
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: sshConfigDir,
sshConfigFile: nbssh.NetBirdSSHConfigFile,
}
}
// getSystemSSHPaths returns platform-specific SSH configuration paths
func getSystemSSHPaths() (configDir, knownHostsDir string) {
switch runtime.GOOS {
case "windows":
configDir, knownHostsDir = getWindowsSSHPaths()
default:
// Unix-like systems (Linux, macOS, etc.)
configDir = "/etc/ssh/ssh_config.d"
knownHostsDir = "/etc/ssh/ssh_known_hosts.d"
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
func getSystemSSHConfigDir() string {
if runtime.GOOS == "windows" {
return getWindowsSSHConfigDir()
}
return configDir, knownHostsDir
return nbssh.UnixSSHConfigDir
}
func getWindowsSSHPaths() (configDir, knownHostsDir string) {
func getWindowsSSHConfigDir() string {
programData := os.Getenv("PROGRAMDATA")
if programData == "" {
programData = `C:\ProgramData`
}
configDir = filepath.Join(programData, "ssh", "ssh_config.d")
knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d")
return configDir, knownHostsDir
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
}
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
func (m *Manager) SetupSSHClientConfig(peerKeys []PeerHostKey) error {
if !shouldGenerateSSHConfig(len(peerKeys)) {
m.logSkipReason(len(peerKeys))
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
if !shouldGenerateSSHConfig(len(peers)) {
m.logSkipReason(len(peers))
return nil
}
knownHostsPath := m.getKnownHostsPath()
sshConfig := m.buildSSHConfig(peerKeys, knownHostsPath)
sshConfig, err := m.buildSSHConfig(peers)
if err != nil {
return fmt.Errorf("build SSH config: %w", err)
}
return m.writeSSHConfig(sshConfig)
}
@@ -183,21 +143,24 @@ func (m *Manager) logSkipReason(peerCount int) {
}
}
func (m *Manager) getKnownHostsPath() string {
knownHostsPath, err := m.setupKnownHostsFile()
if err != nil {
log.Warnf("Failed to setup known_hosts file: %v", err)
return "/dev/null"
}
return knownHostsPath
}
func (m *Manager) buildSSHConfig(peerKeys []PeerHostKey, knownHostsPath string) string {
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
sshConfig := m.buildConfigHeader()
for _, peer := range peerKeys {
sshConfig += m.buildPeerConfig(peer, knownHostsPath)
var allHostPatterns []string
for _, peer := range peers {
hostPatterns := m.buildHostPatterns(peer)
allHostPatterns = append(allHostPatterns, hostPatterns...)
}
return sshConfig
if len(allHostPatterns) > 0 {
peerConfig, err := m.buildPeerConfig(allHostPatterns)
if err != nil {
return "", err
}
sshConfig += peerConfig
}
return sshConfig, nil
}
func (m *Manager) buildConfigHeader() string {
@@ -209,25 +172,49 @@ func (m *Manager) buildConfigHeader() string {
"#\n\n"
}
func (m *Manager) buildPeerConfig(peer PeerHostKey, knownHostsPath string) string {
hostPatterns := m.buildHostPatterns(peer)
if len(hostPatterns) == 0 {
return ""
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
uniquePatterns := make(map[string]bool)
var deduplicatedPatterns []string
for _, pattern := range allHostPatterns {
if !uniquePatterns[pattern] {
uniquePatterns[pattern] = true
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
}
}
hostLine := strings.Join(hostPatterns, " ")
execPath, err := m.getNetBirdExecutablePath()
if err != nil {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
config += " # NetBird peer-specific configuration\n"
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += m.buildHostKeyConfig(knownHostsPath)
config += " LogLevel ERROR\n\n"
return config
if runtime.GOOS == "windows" {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
} else {
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
}
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}
func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
var hostPatterns []string
if peer.IP != "" {
hostPatterns = append(hostPatterns, peer.IP)
@@ -241,280 +228,55 @@ func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
return hostPatterns
}
func (m *Manager) buildHostKeyConfig(knownHostsPath string) string {
if knownHostsPath == "/dev/null" {
return " StrictHostKeyChecking no\n" +
" UserKnownHostsFile /dev/null\n"
}
return " StrictHostKeyChecking yes\n" +
fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath)
}
func (m *Manager) writeSSHConfig(sshConfig string) error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err)
return m.setupUserConfig(sshConfig)
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
}
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
log.Warnf("Failed to write SSH config file %s: %v", sshConfigPath, err)
return m.setupUserConfig(sshConfig)
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
}
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
return nil
}
// setupUserConfig creates SSH config in user's directory as fallback
func (m *Manager) setupUserConfig(sshConfig string) error {
homeDir, err := os.UserHomeDir()
if err != nil {
return fmt.Errorf("get user home directory: %w", err)
}
userSSHDir := filepath.Join(homeDir, ".ssh")
userConfigPath := filepath.Join(userSSHDir, "config")
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
return fmt.Errorf("create user SSH directory: %w", err)
}
// Check if NetBird config already exists in user config
exists, err := m.configExists(userConfigPath)
if err != nil {
return fmt.Errorf("check existing config: %w", err)
}
if exists {
log.Debugf("NetBird SSH config already exists in %s", userConfigPath)
return nil
}
// Append NetBird config to user's SSH config with timeout
if err := writeFileOperationWithTimeout(userConfigPath, func() error {
file, err := os.OpenFile(userConfigPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return fmt.Errorf("open user SSH config: %w", err)
}
defer func() {
if err := file.Close(); err != nil {
log.Debugf("user SSH config file close error: %v", err)
}
}()
if _, err := fmt.Fprintf(file, "\n%s", sshConfig); err != nil {
return fmt.Errorf("write to user SSH config: %w", err)
}
return nil
}); err != nil {
return err
}
log.Infof("Added NetBird SSH config to user config: %s", userConfigPath)
return nil
}
// configExists checks if NetBird SSH config already exists
func (m *Manager) configExists(configPath string) (bool, error) {
file, err := os.Open(configPath)
if err != nil {
if os.IsNotExist(err) {
return false, nil
}
return false, fmt.Errorf("open SSH config file: %w", err)
}
defer file.Close()
scanner := bufio.NewScanner(file)
for scanner.Scan() {
line := strings.TrimSpace(scanner.Text())
if strings.Contains(line, "NetBird SSH client configuration") {
return true, nil
}
}
return false, scanner.Err()
}
// RemoveSSHClientConfig removes NetBird SSH configuration
func (m *Manager) RemoveSSHClientConfig() error {
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
// Remove system-wide config if it exists
if err := os.Remove(sshConfigPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove system SSH config %s: %v", sshConfigPath, err)
} else if err == nil {
err := os.Remove(sshConfigPath)
if err != nil && !os.IsNotExist(err) {
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
}
if err == nil {
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
}
// Also try to clean up user config
homeDir, err := os.UserHomeDir()
if err != nil {
log.Debugf("failed to get user home directory: %v", err)
return nil
}
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
if err := m.removeFromUserConfig(userConfigPath); err != nil {
log.Warnf("Failed to clean user SSH config: %v", err)
}
return nil
}
// removeFromUserConfig removes NetBird section from user's SSH config
func (m *Manager) removeFromUserConfig(configPath string) error {
// This is complex to implement safely, so for now just log
// In practice, the system-wide config takes precedence anyway
log.Debugf("NetBird SSH config cleanup from user config not implemented")
return nil
}
// setupKnownHostsFile creates and returns the path to NetBird known_hosts file
func (m *Manager) setupKnownHostsFile() (string, error) {
// Try system-wide known_hosts first
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
if err := os.MkdirAll(m.knownHostsDir, 0755); err == nil {
// Create empty file if it doesn't exist
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
if err := writeFileWithTimeout(knownHostsPath, []byte("# NetBird SSH known hosts\n"), 0644); err == nil {
log.Debugf("Created NetBird known_hosts file: %s", knownHostsPath)
return knownHostsPath, nil
}
} else if err == nil {
return knownHostsPath, nil
}
}
// Fallback to user directory
homeDir, err := os.UserHomeDir()
func (m *Manager) getNetBirdExecutablePath() (string, error) {
execPath, err := os.Executable()
if err != nil {
return "", fmt.Errorf("get user home directory: %w", err)
return "", fmt.Errorf("retrieve executable path: %w", err)
}
userSSHDir := filepath.Join(homeDir, ".ssh")
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
return "", fmt.Errorf("create user SSH directory: %w", err)
}
userKnownHostsPath := filepath.Join(userSSHDir, m.userKnownHosts)
if _, err := os.Stat(userKnownHostsPath); os.IsNotExist(err) {
if err := writeFileWithTimeout(userKnownHostsPath, []byte("# NetBird SSH known hosts\n"), 0600); err != nil {
return "", fmt.Errorf("create user known_hosts file: %w", err)
}
log.Debugf("Created NetBird user known_hosts file: %s", userKnownHostsPath)
}
return userKnownHostsPath, nil
}
// UpdatePeerHostKeys updates the known_hosts file with peer host keys
func (m *Manager) UpdatePeerHostKeys(peerKeys []PeerHostKey) error {
peerCount := len(peerKeys)
// Check if SSH config should be generated
if !shouldGenerateSSHConfig(peerCount) {
if isSSHConfigDisabled() {
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
} else {
log.Infof("SSH known_hosts update skipped: too many peers (%d > %d). Use %s=true to force.",
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
}
return nil
}
knownHostsPath, err := m.setupKnownHostsFile()
realPath, err := filepath.EvalSymlinks(execPath)
if err != nil {
return fmt.Errorf("setup known_hosts file: %w", err)
log.Debugf("symlink resolution failed: %v", err)
return execPath, nil
}
// Create updated known_hosts content - NetBird file should only contain NetBird entries
var updatedContent strings.Builder
updatedContent.WriteString("# NetBird SSH known hosts\n")
updatedContent.WriteString("# Generated automatically - do not edit manually\n\n")
// Add new NetBird entries - one entry per peer with all hostnames
for _, peerKey := range peerKeys {
entry := m.formatKnownHostsEntry(peerKey)
updatedContent.WriteString(entry)
updatedContent.WriteString("\n")
}
// Write updated content
if err := writeFileWithTimeout(knownHostsPath, []byte(updatedContent.String()), 0644); err != nil {
return fmt.Errorf("write known_hosts file: %w", err)
}
log.Debugf("Updated NetBird known_hosts with %d peer keys: %s", len(peerKeys), knownHostsPath)
return nil
return realPath, nil
}
// formatKnownHostsEntry formats a peer host key as a known_hosts entry
func (m *Manager) formatKnownHostsEntry(peerKey PeerHostKey) string {
hostnames := m.getHostnameVariants(peerKey)
hostnameList := strings.Join(hostnames, ",")
keyString := string(ssh.MarshalAuthorizedKey(peerKey.HostKey))
keyString = strings.TrimSpace(keyString)
return fmt.Sprintf("%s %s", hostnameList, keyString)
// GetSSHConfigDir returns the SSH config directory path
func (m *Manager) GetSSHConfigDir() string {
return m.sshConfigDir
}
// getHostnameVariants returns all possible hostname variants for a peer
func (m *Manager) getHostnameVariants(peerKey PeerHostKey) []string {
var hostnames []string
// Add IP address
if peerKey.IP != "" {
hostnames = append(hostnames, peerKey.IP)
}
// Add FQDN
if peerKey.FQDN != "" {
hostnames = append(hostnames, peerKey.FQDN)
}
// Add hostname if different from FQDN
if peerKey.Hostname != "" && peerKey.Hostname != peerKey.FQDN {
hostnames = append(hostnames, peerKey.Hostname)
}
// Add bracketed IP for non-standard ports (SSH standard)
if peerKey.IP != "" {
hostnames = append(hostnames, fmt.Sprintf("[%s]:22", peerKey.IP))
hostnames = append(hostnames, fmt.Sprintf("[%s]:22022", peerKey.IP))
}
return hostnames
}
// GetKnownHostsPath returns the path to the NetBird known_hosts file
func (m *Manager) GetKnownHostsPath() (string, error) {
return m.setupKnownHostsFile()
}
// RemoveKnownHostsFile removes the NetBird known_hosts file
func (m *Manager) RemoveKnownHostsFile() error {
// Remove system-wide known_hosts if it exists
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
if err := os.Remove(knownHostsPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove system known_hosts %s: %v", knownHostsPath, err)
} else if err == nil {
log.Infof("Removed NetBird known_hosts: %s", knownHostsPath)
}
// Also try to clean up user known_hosts
homeDir, err := os.UserHomeDir()
if err != nil {
log.Debugf("failed to get user home directory: %v", err)
return nil
}
userKnownHostsPath := filepath.Join(homeDir, ".ssh", m.userKnownHosts)
if err := os.Remove(userKnownHostsPath); err != nil && !os.IsNotExist(err) {
log.Warnf("Failed to remove user known_hosts %s: %v", userKnownHostsPath, err)
} else if err == nil {
log.Infof("Removed NetBird user known_hosts: %s", userKnownHostsPath)
}
return nil
// GetSSHConfigFile returns the SSH config file name
func (m *Manager) GetSSHConfigFile() string {
return m.sshConfigFile
}

View File

@@ -10,81 +10,8 @@ import (
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
func TestManager_UpdatePeerHostKeys(t *testing.T) {
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
}
// Generate test host keys
hostKey1, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey1, err := ssh.ParsePrivateKey(hostKey1)
require.NoError(t, err)
hostKey2, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey2, err := ssh.ParsePrivateKey(hostKey2)
require.NoError(t, err)
// Create test peer host keys
peerKeys := []PeerHostKey{
{
Hostname: "peer1",
IP: "100.125.1.1",
FQDN: "peer1.nb.internal",
HostKey: pubKey1.PublicKey(),
},
{
Hostname: "peer2",
IP: "100.125.1.2",
FQDN: "peer2.nb.internal",
HostKey: pubKey2.PublicKey(),
},
}
// Test updating known_hosts
err = manager.UpdatePeerHostKeys(peerKeys)
require.NoError(t, err)
// Verify known_hosts file was created and contains entries
knownHostsPath, err := manager.GetKnownHostsPath()
require.NoError(t, err)
content, err := os.ReadFile(knownHostsPath)
require.NoError(t, err)
contentStr := string(content)
assert.Contains(t, contentStr, "100.125.1.1")
assert.Contains(t, contentStr, "100.125.1.2")
assert.Contains(t, contentStr, "peer1.nb.internal")
assert.Contains(t, contentStr, "peer2.nb.internal")
assert.Contains(t, contentStr, "[100.125.1.1]:22")
assert.Contains(t, contentStr, "[100.125.1.1]:22022")
// Test updating with empty list should preserve structure
err = manager.UpdatePeerHostKeys([]PeerHostKey{})
require.NoError(t, err)
content, err = os.ReadFile(knownHostsPath)
require.NoError(t, err)
assert.Contains(t, string(content), "# NetBird SSH known hosts")
}
func TestManager_SetupSSHClientConfig(t *testing.T) {
// Create temporary directory for test
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
@@ -93,15 +20,25 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Test SSH config generation with empty peer keys
err = manager.SetupSSHClientConfig(nil)
// Test SSH config generation with peers
peers := []PeerSSHInfo{
{
Hostname: "peer1",
IP: "100.125.1.1",
FQDN: "peer1.nb.internal",
},
{
Hostname: "peer2",
IP: "100.125.1.2",
FQDN: "peer2.nb.internal",
},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Read generated config
@@ -111,134 +48,39 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
configStr := string(content)
// Since we now use per-peer configurations instead of domain patterns,
// we should verify the basic SSH config structure exists
// Verify the basic SSH config structure exists
assert.Contains(t, configStr, "# NetBird SSH client configuration")
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
// Should not contain /dev/null since we have a proper known_hosts setup
assert.NotContains(t, configStr, "UserKnownHostsFile /dev/null")
}
// Check that peer hostnames are included
assert.Contains(t, configStr, "100.125.1.1")
assert.Contains(t, configStr, "100.125.1.2")
assert.Contains(t, configStr, "peer1.nb.internal")
assert.Contains(t, configStr, "peer2.nb.internal")
func TestManager_GetHostnameVariants(t *testing.T) {
manager := NewManager()
peerKey := PeerHostKey{
Hostname: "testpeer",
IP: "100.125.1.10",
FQDN: "testpeer.nb.internal",
HostKey: nil, // Not needed for this test
}
variants := manager.getHostnameVariants(peerKey)
expectedVariants := []string{
"100.125.1.10",
"testpeer.nb.internal",
"testpeer",
"[100.125.1.10]:22",
"[100.125.1.10]:22022",
}
assert.ElementsMatch(t, expectedVariants, variants)
}
func TestManager_FormatKnownHostsEntry(t *testing.T) {
manager := NewManager()
// Generate test key
hostKeyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
parsedKey, err := ssh.ParsePrivateKey(hostKeyPEM)
require.NoError(t, err)
peerKey := PeerHostKey{
Hostname: "testpeer",
IP: "100.125.1.10",
FQDN: "testpeer.nb.internal",
HostKey: parsedKey.PublicKey(),
}
entry := manager.formatKnownHostsEntry(peerKey)
// Should contain all hostname variants
assert.Contains(t, entry, "100.125.1.10")
assert.Contains(t, entry, "testpeer.nb.internal")
assert.Contains(t, entry, "testpeer")
assert.Contains(t, entry, "[100.125.1.10]:22")
assert.Contains(t, entry, "[100.125.1.10]:22022")
// Should contain the public key
keyString := string(ssh.MarshalAuthorizedKey(parsedKey.PublicKey()))
keyString = strings.TrimSpace(keyString)
assert.Contains(t, entry, keyString)
// Should be properly formatted (hostnames followed by key)
parts := strings.Fields(entry)
assert.GreaterOrEqual(t, len(parts), 2, "Entry should have hostnames and key parts")
}
func TestManager_DirectoryFallback(t *testing.T) {
// Create temporary directory for test where system dirs will fail
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
// Set HOME to temp directory to control user fallback
t.Setenv("HOME", tempDir)
// Create manager with non-writable system directories
// Use paths that will fail on all systems
var failPath string
// Check platform-specific UserKnownHostsFile
if runtime.GOOS == "windows" {
failPath = "NUL:" // Special device that can't be used as directory on Windows
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
} else {
failPath = "/dev/null" // Special device that can't be used as directory on Unix
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
}
manager := &Manager{
sshConfigDir: failPath + "/ssh_config.d", // Should fail
sshConfigFile: "99-netbird.conf",
knownHostsDir: failPath + "/ssh_known_hosts.d", // Should fail
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
}
// Should fall back to user directory
knownHostsPath, err := manager.setupKnownHostsFile()
require.NoError(t, err)
// Get the actual user home directory as determined by os.UserHomeDir()
userHome, err := os.UserHomeDir()
require.NoError(t, err)
expectedUserPath := filepath.Join(userHome, ".ssh", "known_hosts_netbird")
assert.Equal(t, expectedUserPath, knownHostsPath)
// Verify file was created
_, err = os.Stat(knownHostsPath)
require.NoError(t, err)
}
func TestGetSystemSSHPaths(t *testing.T) {
configDir, knownHostsDir := getSystemSSHPaths()
func TestGetSystemSSHConfigDir(t *testing.T) {
configDir := getSystemSSHConfigDir()
// Paths should not be empty
// Path should not be empty
assert.NotEmpty(t, configDir)
assert.NotEmpty(t, knownHostsDir)
// Should be absolute paths
// Should be an absolute path
assert.True(t, filepath.IsAbs(configDir))
assert.True(t, filepath.IsAbs(knownHostsDir))
// On Unix systems, should start with /etc
// On Windows, should contain ProgramData
if runtime.GOOS == "windows" {
assert.Contains(t, strings.ToLower(configDir), "programdata")
assert.Contains(t, strings.ToLower(knownHostsDir), "programdata")
} else {
assert.Contains(t, configDir, "/etc/ssh")
assert.Contains(t, knownHostsDir, "/etc/ssh")
}
}
@@ -250,46 +92,28 @@ func TestManager_PeerLimit(t *testing.T) {
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Generate many peer keys (more than limit)
var peerKeys []PeerHostKey
// Generate many peers (more than limit)
var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey, err := ssh.ParsePrivateKey(hostKey)
require.NoError(t, err)
peerKeys = append(peerKeys, PeerHostKey{
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
HostKey: pubKey.PublicKey(),
})
}
// Test that SSH config generation is skipped when too many peers
err = manager.SetupSSHClientConfig(peerKeys)
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Config should not be created due to peer limit
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
_, err = os.Stat(configPath)
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
// Test that known_hosts update is also skipped
err = manager.UpdatePeerHostKeys(peerKeys)
require.NoError(t, err)
// Known hosts should not be created due to peer limit
knownHostsPath := filepath.Join(manager.knownHostsDir, manager.knownHostsFile)
_, err = os.Stat(knownHostsPath)
assert.True(t, os.IsNotExist(err), "Known hosts should not be created with too many peers")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
@@ -303,31 +127,22 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
// Override manager paths to use temp directory
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
knownHostsFile: "99-netbird",
userKnownHosts: "known_hosts_netbird",
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
// Generate many peer keys (more than limit)
var peerKeys []PeerHostKey
// Generate many peers (more than limit)
var peers []PeerSSHInfo
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
pubKey, err := ssh.ParsePrivateKey(hostKey)
require.NoError(t, err)
peerKeys = append(peerKeys, PeerHostKey{
peers = append(peers, PeerSSHInfo{
Hostname: fmt.Sprintf("peer%d", i),
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
HostKey: pubKey.PublicKey(),
})
}
// Test that SSH config generation is forced despite many peers
err = manager.SetupSSHClientConfig(peerKeys)
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
// Config should be created despite peer limit due to force flag

View File

@@ -0,0 +1,22 @@
package config
// ShutdownState represents SSH configuration state that needs to be cleaned up.
type ShutdownState struct {
SSHConfigDir string
SSHConfigFile string
}
// Name returns the state name for the state manager.
func (s *ShutdownState) Name() string {
return "ssh_config_state"
}
// Cleanup removes SSH client configuration files.
func (s *ShutdownState) Cleanup() error {
manager := &Manager{
sshConfigDir: s.SSHConfigDir,
sshConfigFile: s.SSHConfigFile,
}
return manager.RemoveSSHClientConfig()
}

View File

@@ -0,0 +1,99 @@
package detection
import (
"bufio"
"context"
"net"
"strconv"
"strings"
"time"
log "github.com/sirupsen/logrus"
)
const (
// ServerIdentifier is the base response for NetBird SSH servers
ServerIdentifier = "NetBird-SSH-Server"
// ProxyIdentifier is the base response for NetBird SSH proxy
ProxyIdentifier = "NetBird-SSH-Proxy"
// JWTRequiredMarker is appended to responses when JWT is required
JWTRequiredMarker = "NetBird-JWT-Required"
// Timeout is the timeout for SSH server detection
Timeout = 5 * time.Second
)
type ServerType string
const (
ServerTypeNetBirdJWT ServerType = "netbird-jwt"
ServerTypeNetBirdNoJWT ServerType = "netbird-no-jwt"
ServerTypeRegular ServerType = "regular"
)
// Dialer provides network connection capabilities
type Dialer interface {
DialContext(ctx context.Context, network, address string) (net.Conn, error)
}
// RequiresJWT checks if the server type requires JWT authentication
func (s ServerType) RequiresJWT() bool {
return s == ServerTypeNetBirdJWT
}
// ExitCode returns the exit code for the detect command
func (s ServerType) ExitCode() int {
switch s {
case ServerTypeNetBirdJWT:
return 0
case ServerTypeNetBirdNoJWT:
return 1
case ServerTypeRegular:
return 2
default:
return 2
}
}
// DetectSSHServerType detects SSH server type using the provided dialer
func DetectSSHServerType(ctx context.Context, dialer Dialer, host string, port int) (ServerType, error) {
targetAddr := net.JoinHostPort(host, strconv.Itoa(port))
conn, err := dialer.DialContext(ctx, "tcp", targetAddr)
if err != nil {
log.Debugf("SSH connection failed for detection: %v", err)
return ServerTypeRegular, nil
}
defer conn.Close()
if err := conn.SetReadDeadline(time.Now().Add(Timeout)); err != nil {
log.Debugf("set read deadline: %v", err)
return ServerTypeRegular, nil
}
reader := bufio.NewReader(conn)
serverBanner, err := reader.ReadString('\n')
if err != nil {
log.Debugf("read SSH banner: %v", err)
return ServerTypeRegular, nil
}
serverBanner = strings.TrimSpace(serverBanner)
log.Debugf("SSH server banner: %s", serverBanner)
if !strings.HasPrefix(serverBanner, "SSH-") {
log.Debugf("Invalid SSH banner")
return ServerTypeRegular, nil
}
if !strings.Contains(serverBanner, ServerIdentifier) {
log.Debugf("Server banner does not contain identifier '%s'", ServerIdentifier)
return ServerTypeRegular, nil
}
if strings.Contains(serverBanner, JWTRequiredMarker) {
return ServerTypeNetBirdJWT, nil
}
return ServerTypeNetBirdNoJWT, nil
}

359
client/ssh/proxy/proxy.go Normal file
View File

@@ -0,0 +1,359 @@
package proxy
import (
"context"
"errors"
"fmt"
"io"
"net"
"os"
"strconv"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/version"
)
const (
// sshConnectionTimeout is the timeout for SSH TCP connection establishment
sshConnectionTimeout = 120 * time.Second
// sshHandshakeTimeout is the timeout for SSH handshake completion
sshHandshakeTimeout = 30 * time.Second
jwtAuthErrorMsg = "JWT authentication: %w"
)
type SSHProxy struct {
daemonAddr string
targetHost string
targetPort int
stderr io.Writer
daemonClient proto.DaemonServiceClient
}
func New(daemonAddr, targetHost string, targetPort int, stderr io.Writer) (*SSHProxy, error) {
grpcAddr := strings.TrimPrefix(daemonAddr, "tcp://")
grpcConn, err := grpc.NewClient(grpcAddr, grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
return nil, fmt.Errorf("connect to daemon: %w", err)
}
return &SSHProxy{
daemonAddr: daemonAddr,
targetHost: targetHost,
targetPort: targetPort,
stderr: stderr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}, nil
}
func (p *SSHProxy) Connect(ctx context.Context) error {
jwtToken, err := nbssh.RequestJWTToken(ctx, p.daemonClient, nil, p.stderr, true)
if err != nil {
return fmt.Errorf(jwtAuthErrorMsg, err)
}
return p.runProxySSHServer(ctx, jwtToken)
}
func (p *SSHProxy) runProxySSHServer(ctx context.Context, jwtToken string) error {
serverVersion := fmt.Sprintf("%s-%s", detection.ProxyIdentifier, version.NetbirdVersion())
sshServer := &ssh.Server{
Handler: func(s ssh.Session) {
p.handleSSHSession(ctx, s, jwtToken)
},
ChannelHandlers: map[string]ssh.ChannelHandler{
"session": ssh.DefaultSessionHandler,
"direct-tcpip": p.directTCPIPHandler,
},
SubsystemHandlers: map[string]ssh.SubsystemHandler{
"sftp": func(s ssh.Session) {
p.sftpSubsystemHandler(s, jwtToken)
},
},
RequestHandlers: map[string]ssh.RequestHandler{
"tcpip-forward": p.tcpipForwardHandler,
"cancel-tcpip-forward": p.cancelTcpipForwardHandler,
},
Version: serverVersion,
}
hostKey, err := generateHostKey()
if err != nil {
return fmt.Errorf("generate host key: %w", err)
}
sshServer.HostSigners = []ssh.Signer{hostKey}
conn := &stdioConn{
stdin: os.Stdin,
stdout: os.Stdout,
}
sshServer.HandleConn(conn)
return nil
}
func (p *SSHProxy) handleSSHSession(ctx context.Context, session ssh.Session, jwtToken string) {
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, session.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "SSH connection to NetBird server failed: %v\n", err)
return
}
defer func() { _ = sshClient.Close() }()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(p.stderr, "create server session: %v\n", err)
return
}
defer func() { _ = serverSession.Close() }()
serverSession.Stdin = session
serverSession.Stdout = session
serverSession.Stderr = session.Stderr()
ptyReq, winCh, isPty := session.Pty()
if isPty {
_ = serverSession.RequestPty(ptyReq.Term, ptyReq.Window.Width, ptyReq.Window.Height, nil)
go func() {
for win := range winCh {
_ = serverSession.WindowChange(win.Height, win.Width)
}
}()
}
if len(session.Command()) > 0 {
_ = serverSession.Run(strings.Join(session.Command(), " "))
return
}
if err = serverSession.Shell(); err == nil {
_ = serverSession.Wait()
}
}
func generateHostKey() (ssh.Signer, error) {
keyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
return nil, fmt.Errorf("generate ED25519 key: %w", err)
}
signer, err := cryptossh.ParsePrivateKey(keyPEM)
if err != nil {
return nil, fmt.Errorf("parse private key: %w", err)
}
return signer, nil
}
type stdioConn struct {
stdin io.Reader
stdout io.Writer
closed bool
mu sync.Mutex
}
func (c *stdioConn) Read(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.EOF
}
c.mu.Unlock()
return c.stdin.Read(b)
}
func (c *stdioConn) Write(b []byte) (n int, err error) {
c.mu.Lock()
if c.closed {
c.mu.Unlock()
return 0, io.ErrClosedPipe
}
c.mu.Unlock()
return c.stdout.Write(b)
}
func (c *stdioConn) Close() error {
c.mu.Lock()
defer c.mu.Unlock()
c.closed = true
return nil
}
func (c *stdioConn) LocalAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) RemoteAddr() net.Addr {
return &net.UnixAddr{Name: "stdio", Net: "unix"}
}
func (c *stdioConn) SetDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetReadDeadline(_ time.Time) error {
return nil
}
func (c *stdioConn) SetWriteDeadline(_ time.Time) error {
return nil
}
func (p *SSHProxy) directTCPIPHandler(_ *ssh.Server, _ *cryptossh.ServerConn, newChan cryptossh.NewChannel, _ ssh.Context) {
_ = newChan.Reject(cryptossh.Prohibited, "port forwarding not supported in proxy")
}
func (p *SSHProxy) sftpSubsystemHandler(s ssh.Session, jwtToken string) {
ctx, cancel := context.WithCancel(s.Context())
defer cancel()
targetAddr := net.JoinHostPort(p.targetHost, strconv.Itoa(p.targetPort))
sshClient, err := p.dialBackend(ctx, targetAddr, s.User(), jwtToken)
if err != nil {
_, _ = fmt.Fprintf(s, "SSH connection failed: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := sshClient.Close(); err != nil {
log.Debugf("close SSH client: %v", err)
}
}()
serverSession, err := sshClient.NewSession()
if err != nil {
_, _ = fmt.Fprintf(s, "create server session: %v\n", err)
_ = s.Exit(1)
return
}
defer func() {
if err := serverSession.Close(); err != nil {
log.Debugf("close server session: %v", err)
}
}()
stdin, stdout, err := p.setupSFTPPipes(serverSession)
if err != nil {
log.Debugf("setup SFTP pipes: %v", err)
_ = s.Exit(1)
return
}
if err := serverSession.RequestSubsystem("sftp"); err != nil {
_, _ = fmt.Fprintf(s, "SFTP subsystem request failed: %v\n", err)
_ = s.Exit(1)
return
}
p.runSFTPBridge(ctx, s, stdin, stdout, serverSession)
}
func (p *SSHProxy) setupSFTPPipes(serverSession *cryptossh.Session) (io.WriteCloser, io.Reader, error) {
stdin, err := serverSession.StdinPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdin pipe: %w", err)
}
stdout, err := serverSession.StdoutPipe()
if err != nil {
return nil, nil, fmt.Errorf("get stdout pipe: %w", err)
}
return stdin, stdout, nil
}
func (p *SSHProxy) runSFTPBridge(ctx context.Context, s ssh.Session, stdin io.WriteCloser, stdout io.Reader, serverSession *cryptossh.Session) {
copyErrCh := make(chan error, 2)
go func() {
_, err := io.Copy(stdin, s)
if err != nil {
log.Debugf("SFTP client to server copy: %v", err)
}
if err := stdin.Close(); err != nil {
log.Debugf("close stdin: %v", err)
}
copyErrCh <- err
}()
go func() {
_, err := io.Copy(s, stdout)
if err != nil {
log.Debugf("SFTP server to client copy: %v", err)
}
copyErrCh <- err
}()
go func() {
<-ctx.Done()
if err := serverSession.Close(); err != nil {
log.Debugf("force close server session on context cancellation: %v", err)
}
}()
for i := 0; i < 2; i++ {
if err := <-copyErrCh; err != nil && !errors.Is(err, io.EOF) {
log.Debugf("SFTP copy error: %v", err)
}
}
if err := serverSession.Wait(); err != nil {
log.Debugf("SFTP session ended: %v", err)
}
}
func (p *SSHProxy) tcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return false, []byte("port forwarding not supported in proxy")
}
func (p *SSHProxy) cancelTcpipForwardHandler(_ ssh.Context, _ *ssh.Server, _ *cryptossh.Request) (bool, []byte) {
return true, nil
}
func (p *SSHProxy) dialBackend(ctx context.Context, addr, user, jwtToken string) (*cryptossh.Client, error) {
config := &cryptossh.ClientConfig{
User: user,
Auth: []cryptossh.AuthMethod{cryptossh.Password(jwtToken)},
Timeout: sshHandshakeTimeout,
HostKeyCallback: p.verifyHostKey,
}
dialer := &net.Dialer{
Timeout: sshConnectionTimeout,
}
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return nil, fmt.Errorf("connect to server: %w", err)
}
clientConn, chans, reqs, err := cryptossh.NewClientConn(conn, addr, config)
if err != nil {
_ = conn.Close()
return nil, fmt.Errorf("SSH handshake: %w", err)
}
return cryptossh.NewClient(clientConn, chans, reqs), nil
}
func (p *SSHProxy) verifyHostKey(hostname string, remote net.Addr, key cryptossh.PublicKey) error {
verifier := nbssh.NewDaemonHostKeyVerifier(p.daemonClient)
callback := nbssh.CreateHostKeyCallback(verifier)
return callback(hostname, remote, key)
}

View File

@@ -0,0 +1,361 @@
package proxy
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"fmt"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"os"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
"github.com/netbirdio/netbird/client/proto"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/server"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
)
func TestMain(m *testing.M) {
if len(os.Args) > 2 && os.Args[1] == "ssh" {
if os.Args[2] == "exec" {
if len(os.Args) > 3 {
cmd := os.Args[3]
if cmd == "echo" && len(os.Args) > 4 {
fmt.Fprintln(os.Stdout, os.Args[4])
os.Exit(0)
}
}
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' with args: %v - preventing infinite recursion\n", os.Args)
os.Exit(1)
}
}
code := m.Run()
testutil.CleanupTestUsers()
os.Exit(code)
}
func TestSSHProxy_verifyHostKey(t *testing.T) {
t.Run("calls daemon to verify host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
testKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testPubKey, err := nbssh.GeneratePublicKey(testKey)
require.NoError(t, err)
mockDaemon.setHostKey("test-host", testPubKey)
err = proxy.verifyHostKey("test-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, testPubKey))
assert.NoError(t, err)
})
t.Run("rejects unknown host key", func(t *testing.T) {
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
grpcConn, err := grpc.NewClient(mockDaemon.addr, grpc.WithTransportCredentials(insecure.NewCredentials()))
require.NoError(t, err)
defer func() { _ = grpcConn.Close() }()
proxy := &SSHProxy{
daemonAddr: mockDaemon.addr,
daemonClient: proto.NewDaemonServiceClient(grpcConn),
}
unknownKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
unknownPubKey, err := nbssh.GeneratePublicKey(unknownKey)
require.NoError(t, err)
err = proxy.verifyHostKey("unknown-host", &net.TCPAddr{IP: net.ParseIP("127.0.0.1"), Port: 22}, mustParsePublicKey(t, unknownPubKey))
assert.Error(t, err)
assert.Contains(t, err.Error(), "peer unknown-host not found in network")
})
}
func TestSSHProxy_Connect(t *testing.T) {
if testing.Short() {
t.Skip("Skipping integration test in short mode")
}
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
hostPubKey, err := nbssh.GeneratePublicKey(hostKey)
require.NoError(t, err)
serverConfig := &server.Config{
HostKeyPEM: hostKey,
JWT: &server.JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
},
}
sshServer := server.New(serverConfig)
sshServer.SetAllowRootLogin(true)
sshServerAddr := server.StartTestServer(t, sshServer)
defer func() { _ = sshServer.Stop() }()
mockDaemon := startMockDaemon(t)
defer mockDaemon.stop()
host, portStr, err := net.SplitHostPort(sshServerAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
mockDaemon.setHostKey(host, hostPubKey)
validToken := generateValidJWT(t, privateKey, issuer, audience)
mockDaemon.setJWTToken(validToken)
proxyInstance, err := New(mockDaemon.addr, host, port, nil)
require.NoError(t, err)
clientConn, proxyConn := net.Pipe()
defer func() { _ = clientConn.Close() }()
origStdin := os.Stdin
origStdout := os.Stdout
defer func() {
os.Stdin = origStdin
os.Stdout = origStdout
}()
stdinReader, stdinWriter, err := os.Pipe()
require.NoError(t, err)
stdoutReader, stdoutWriter, err := os.Pipe()
require.NoError(t, err)
os.Stdin = stdinReader
os.Stdout = stdoutWriter
go func() {
_, _ = io.Copy(stdinWriter, proxyConn)
}()
go func() {
_, _ = io.Copy(proxyConn, stdoutReader)
}()
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
connectErrCh := make(chan error, 1)
go func() {
connectErrCh <- proxyInstance.Connect(ctx)
}()
sshConfig := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 3 * time.Second,
}
sshClientConn, chans, reqs, err := cryptossh.NewClientConn(clientConn, "test", sshConfig)
require.NoError(t, err, "Should connect to proxy server")
defer func() { _ = sshClientConn.Close() }()
sshClient := cryptossh.NewClient(sshClientConn, chans, reqs)
session, err := sshClient.NewSession()
require.NoError(t, err, "Should create session through full proxy to backend")
outputCh := make(chan []byte, 1)
errCh := make(chan error, 1)
go func() {
output, err := session.Output("echo hello-from-proxy")
outputCh <- output
errCh <- err
}()
select {
case output := <-outputCh:
err := <-errCh
require.NoError(t, err, "Command should execute successfully through proxy")
assert.Contains(t, string(output), "hello-from-proxy", "Should receive command output through proxy")
case <-time.After(3 * time.Second):
t.Fatal("Command execution timed out")
}
_ = session.Close()
_ = sshClient.Close()
_ = clientConn.Close()
cancel()
}
type mockDaemonServer struct {
proto.UnimplementedDaemonServiceServer
hostKeys map[string][]byte
jwtToken string
}
func (m *mockDaemonServer) GetPeerSSHHostKey(ctx context.Context, req *proto.GetPeerSSHHostKeyRequest) (*proto.GetPeerSSHHostKeyResponse, error) {
key, found := m.hostKeys[req.PeerAddress]
return &proto.GetPeerSSHHostKeyResponse{
Found: found,
SshHostKey: key,
}, nil
}
func (m *mockDaemonServer) RequestJWTAuth(ctx context.Context, req *proto.RequestJWTAuthRequest) (*proto.RequestJWTAuthResponse, error) {
return &proto.RequestJWTAuthResponse{
CachedToken: m.jwtToken,
}, nil
}
func (m *mockDaemonServer) WaitJWTToken(ctx context.Context, req *proto.WaitJWTTokenRequest) (*proto.WaitJWTTokenResponse, error) {
return &proto.WaitJWTTokenResponse{
Token: m.jwtToken,
}, nil
}
type mockDaemon struct {
addr string
server *grpc.Server
impl *mockDaemonServer
}
func startMockDaemon(t *testing.T) *mockDaemon {
t.Helper()
listener, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
impl := &mockDaemonServer{
hostKeys: make(map[string][]byte),
jwtToken: "test-jwt-token",
}
grpcServer := grpc.NewServer()
proto.RegisterDaemonServiceServer(grpcServer, impl)
go func() {
_ = grpcServer.Serve(listener)
}()
return &mockDaemon{
addr: listener.Addr().String(),
server: grpcServer,
impl: impl,
}
}
func (m *mockDaemon) setHostKey(addr string, pubKey []byte) {
m.impl.hostKeys[addr] = pubKey
}
func (m *mockDaemon) setJWTToken(token string) {
m.impl.jwtToken = token
}
func (m *mockDaemon) stop() {
if m.server != nil {
m.server.Stop()
}
}
func mustParsePublicKey(t *testing.T, pubKeyBytes []byte) cryptossh.PublicKey {
t.Helper()
pubKey, _, _, _, err := cryptossh.ParseAuthorizedKey(pubKeyBytes)
require.NoError(t, err)
return pubKey
}
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
t.Helper()
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
t.Helper()
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64.RawURLEncoding.EncodeToString(n),
E: base64.RawURLEncoding.EncodeToString(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
t.Helper()
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}

View File

@@ -9,7 +9,6 @@ import (
"net"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
@@ -21,15 +20,24 @@ import (
"golang.org/x/crypto/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/testutil"
)
// TestMain handles package-level setup and cleanup
func TestMain(m *testing.M) {
// Guard against infinite recursion when test binary is called as "netbird ssh exec"
// This happens when running tests as non-privileged user with fallback
if len(os.Args) > 2 && os.Args[1] == "ssh" && os.Args[2] == "exec" {
// Just exit with error to break the recursion
fmt.Fprintf(os.Stderr, "Test binary called as 'ssh exec' - preventing infinite recursion\n")
os.Exit(1)
}
// Run tests
code := m.Run()
// Cleanup any created test users
cleanupTestUsers()
testutil.CleanupTestUsers()
os.Exit(code)
}
@@ -50,13 +58,15 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err)
// Generate OpenSSH-compatible keys for client
clientPrivKeyOpenSSH, clientPubKeyOpenSSH, err := generateOpenSSHKey(t)
clientPrivKeyOpenSSH, _, err := generateOpenSSHKey(t)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKeyOpenSSH))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -73,7 +83,7 @@ func TestSSHServerCompatibility(t *testing.T) {
require.NoError(t, err)
// Get appropriate user for SSH connection (handle system accounts)
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
t.Run("basic command execution", func(t *testing.T) {
testSSHCommandExecutionWithUser(t, host, portStr, clientKeyFile, username)
@@ -113,7 +123,7 @@ func testSSHCommandExecutionWithUser(t *testing.T, host, port, keyFile, username
// testSSHInteractiveCommand tests interactive shell session.
func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
@@ -178,7 +188,7 @@ func testSSHInteractiveCommand(t *testing.T, host, port, keyFile string) {
// testSSHPortForwarding tests port forwarding compatibility.
func testSSHPortForwarding(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
testServer, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
@@ -401,7 +411,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
t.Skip("Skipping SSH feature compatibility tests in short mode")
}
if runtime.GOOS == "windows" && isCI() {
if runtime.GOOS == "windows" && testutil.IsCI() {
t.Skip("Skipping Windows SSH compatibility tests in CI due to S4U authentication issues")
}
@@ -438,13 +448,13 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -468,7 +478,7 @@ func TestSSHServerFeatureCompatibility(t *testing.T) {
// testCommandWithFlags tests that commands with flags work properly
func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Test ls with flags
cmd := exec.Command("ssh",
@@ -495,7 +505,7 @@ func testCommandWithFlags(t *testing.T, host, port, keyFile string) {
// testEnvironmentVariables tests that environment is properly set up
func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
cmd := exec.Command("ssh",
"-i", keyFile,
@@ -522,7 +532,7 @@ func testEnvironmentVariables(t *testing.T, host, port, keyFile string) {
// testExitCodes tests that exit codes are properly handled
func testExitCodes(t *testing.T, host, port, keyFile string) {
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Test successful command (exit code 0)
cmd := exec.Command("ssh",
@@ -567,7 +577,7 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
}
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Set up SSH server with specific security settings
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -575,13 +585,13 @@ func TestSSHServerSecurityFeatures(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -652,7 +662,7 @@ func TestCrossPlatformCompatibility(t *testing.T) {
}
// Get appropriate user for SSH connection
username := getTestUsername(t)
username := testutil.GetTestUsername(t)
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
@@ -660,13 +670,13 @@ func TestCrossPlatformCompatibility(t *testing.T) {
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -710,171 +720,3 @@ func TestCrossPlatformCompatibility(t *testing.T) {
t.Logf("Platform command output: %s", outputStr)
assert.NotEmpty(t, outputStr, "Platform-specific command should produce output")
}
// getTestUsername returns an appropriate username for testing
func getTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
// Check if this is a system account that can't authenticate
if isSystemAccount(currentUser.Username) {
// In CI environments, create a test user; otherwise try Administrator
if isCI() {
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
// Try Administrator first for local development
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := getOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// isCI checks if we're running in a CI environment
func isCI() bool {
// Check standard CI environment variables
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
// Check for GitHub Actions runner hostname pattern (when running as SYSTEM)
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// isSystemAccount checks if the user is a system account that can't authenticate
func isSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
var compatTestCreatedUsers = make(map[string]bool)
var compatTestUsersToCleanup []string
// registerTestUserCleanup registers a test user for cleanup
func registerTestUserCleanup(username string) {
if !compatTestCreatedUsers[username] {
compatTestCreatedUsers[username] = true
compatTestUsersToCleanup = append(compatTestUsersToCleanup, username)
}
}
// cleanupTestUsers removes all created test users
func cleanupTestUsers() {
for _, username := range compatTestUsersToCleanup {
removeWindowsTestUser(username)
}
compatTestUsersToCleanup = nil
compatTestCreatedUsers = make(map[string]bool)
}
// getOrCreateTestUser creates a test user on Windows if needed
func getOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
// Check if user already exists
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
// Try to create the user using PowerShell
if createWindowsTestUser(t, testUsername) {
// Register cleanup for the test user
registerTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// removeWindowsTestUser removes a local user on Windows using PowerShell
func removeWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
// PowerShell command to remove a local user
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// createWindowsTestUser creates a local user on Windows using PowerShell
func createWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
// PowerShell command to create a local user
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}

View File

@@ -0,0 +1,610 @@
package server
import (
"context"
"crypto/rand"
"crypto/rsa"
"encoding/base64"
"encoding/json"
"io"
"math/big"
"net"
"net/http"
"net/http/httptest"
"strconv"
"testing"
"time"
"github.com/golang-jwt/jwt/v5"
cryptossh "golang.org/x/crypto/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/client"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/ssh/testutil"
nbjwt "github.com/netbirdio/netbird/management/server/auth/jwt"
)
func TestJWTEnforcement(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT enforcement tests in short mode")
}
// Set up SSH server
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
t.Run("blocks_without_jwt", func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
if err != nil {
t.Logf("Detection failed: %v", err)
}
t.Logf("Detected server type: %s", serverType)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
_, err = cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
assert.Error(t, err, "SSH connection should fail when JWT is required but not provided")
})
t.Run("allows_when_disabled", func(t *testing.T) {
serverConfigNoJWT := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
serverNoJWT := New(serverConfigNoJWT)
require.False(t, serverNoJWT.jwtEnabled, "JWT should be disabled without config")
serverNoJWT.SetAllowRootLogin(true)
serverAddrNoJWT := StartTestServer(t, serverNoJWT)
defer require.NoError(t, serverNoJWT.Stop())
hostNoJWT, portStrNoJWT, err := net.SplitHostPort(serverAddrNoJWT)
require.NoError(t, err)
portNoJWT, err := strconv.Atoi(portStrNoJWT)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, hostNoJWT, portNoJWT)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdNoJWT, serverType)
assert.False(t, serverType.RequiresJWT())
client, err := connectWithNetBirdClient(t, hostNoJWT, portNoJWT)
require.NoError(t, err)
defer client.Close()
})
}
// setupJWKSServer creates a test HTTP server serving JWKS and returns the server, private key, and URL
func setupJWKSServer(t *testing.T) (*httptest.Server, *rsa.PrivateKey, string) {
privateKey, jwksJSON := generateTestJWKS(t)
server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "application/json")
if _, err := w.Write(jwksJSON); err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
}
}))
return server, privateKey, server.URL
}
// generateTestJWKS creates a test RSA key pair and returns private key and JWKS JSON
func generateTestJWKS(t *testing.T) (*rsa.PrivateKey, []byte) {
privateKey, err := rsa.GenerateKey(rand.Reader, 2048)
require.NoError(t, err)
publicKey := &privateKey.PublicKey
n := publicKey.N.Bytes()
e := publicKey.E
jwk := nbjwt.JSONWebKey{
Kty: "RSA",
Kid: "test-key-id",
Use: "sig",
N: base64RawURLEncode(n),
E: base64RawURLEncode(big.NewInt(int64(e)).Bytes()),
}
jwks := nbjwt.Jwks{
Keys: []nbjwt.JSONWebKey{jwk},
}
jwksJSON, err := json.Marshal(jwks)
require.NoError(t, err)
return privateKey, jwksJSON
}
func base64RawURLEncode(data []byte) string {
return base64.RawURLEncoding.EncodeToString(data)
}
// generateValidJWT creates a valid JWT token for testing
func generateValidJWT(t *testing.T, privateKey *rsa.PrivateKey, issuer, audience string) string {
claims := jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
}
token := jwt.NewWithClaims(jwt.SigningMethodRS256, claims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
return tokenString
}
// connectWithNetBirdClient connects to SSH server using NetBird's SSH client
func connectWithNetBirdClient(t *testing.T, host string, port int) (*client.Client, error) {
t.Helper()
addr := net.JoinHostPort(host, strconv.Itoa(port))
ctx := context.Background()
return client.Dial(ctx, addr, testutil.GetTestUsername(t), client.DialOptions{
InsecureSkipVerify: true,
})
}
// TestJWTDetection tests that server detection correctly identifies JWT-enabled servers
func TestJWTDetection(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT detection test in short mode")
}
jwksServer, _, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
port, err := strconv.Atoi(portStr)
require.NoError(t, err)
dialer := &net.Dialer{Timeout: detection.Timeout}
serverType, err := detection.DetectSSHServerType(context.Background(), dialer, host, port)
require.NoError(t, err)
assert.Equal(t, detection.ServerTypeNetBirdJWT, serverType)
assert.True(t, serverType.RequiresJWT())
}
func TestJWTFailClose(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT fail-close tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
tokenClaims jwt.MapClaims
}{
{
name: "blocks_token_missing_iat",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
},
},
{
name: "blocks_token_missing_sub",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_iss",
tokenClaims: jwt.MapClaims{
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_missing_aud",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_issuer",
tokenClaims: jwt.MapClaims{
"iss": "wrong-issuer",
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_token_wrong_audience",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": "wrong-audience",
"sub": "test-user",
"exp": time.Now().Add(time.Hour).Unix(),
"iat": time.Now().Unix(),
},
},
{
name: "blocks_expired_token",
tokenClaims: jwt.MapClaims{
"iss": issuer,
"aud": audience,
"sub": "test-user",
"exp": time.Now().Add(-time.Hour).Unix(),
"iat": time.Now().Add(-2 * time.Hour).Unix(),
},
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
MaxTokenAge: 3600,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
token := jwt.NewWithClaims(jwt.SigningMethodRS256, tc.tokenClaims)
token.Header["kid"] = "test-key-id"
tokenString, err := token.SignedString(privateKey)
require.NoError(t, err)
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: []cryptossh.AuthMethod{
cryptossh.Password(tokenString),
},
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
assert.Error(t, err, "Authentication should fail (fail-close)")
})
}
}
// TestJWTAuthentication tests JWT authentication with valid/invalid tokens and enforcement for various connection types
func TestJWTAuthentication(t *testing.T) {
if testing.Short() {
t.Skip("Skipping JWT authentication tests in short mode")
}
jwksServer, privateKey, jwksURL := setupJWKSServer(t)
defer jwksServer.Close()
const (
issuer = "https://test-issuer.example.com"
audience = "test-audience"
)
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
testCases := []struct {
name string
token string
wantAuthOK bool
setupServer func(*Server)
testOperation func(*testing.T, *cryptossh.Client, string) error
wantOpSuccess bool
}{
{
name: "allows_shell_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
return session.Shell()
},
wantOpSuccess: true,
},
{
name: "rejects_invalid_token",
token: "invalid",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_shell_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("echo test")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "blocks_command_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
output, err := session.CombinedOutput("ls")
if err != nil {
t.Logf("Command output: %s", string(output))
return err
}
return nil
},
wantOpSuccess: false,
},
{
name: "allows_sftp_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
return session.RequestSubsystem("sftp")
},
wantOpSuccess: true,
},
{
name: "blocks_sftp_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowSFTP(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
session, err := conn.NewSession()
require.NoError(t, err)
defer session.Close()
session.Stdout = io.Discard
session.Stderr = io.Discard
err = session.RequestSubsystem("sftp")
if err == nil {
err = session.Wait()
}
return err
},
wantOpSuccess: false,
},
{
name: "allows_port_forward_with_jwt",
token: "valid",
wantAuthOK: true,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowRemotePortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: true,
},
{
name: "blocks_port_forward_without_jwt",
token: "",
wantAuthOK: false,
setupServer: func(s *Server) {
s.SetAllowRootLogin(true)
s.SetAllowLocalPortForwarding(true)
},
testOperation: func(t *testing.T, conn *cryptossh.Client, _ string) error {
ln, err := conn.Listen("tcp", "127.0.0.1:0")
if ln != nil {
defer ln.Close()
}
return err
},
wantOpSuccess: false,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
jwtConfig := &JWTConfig{
Issuer: issuer,
Audience: audience,
KeysLocation: jwksURL,
}
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: jwtConfig,
}
server := New(serverConfig)
if tc.setupServer != nil {
tc.setupServer(server)
}
serverAddr := StartTestServer(t, server)
defer require.NoError(t, server.Stop())
host, portStr, err := net.SplitHostPort(serverAddr)
require.NoError(t, err)
var authMethods []cryptossh.AuthMethod
if tc.token == "valid" {
token := generateValidJWT(t, privateKey, issuer, audience)
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(token),
}
} else if tc.token == "invalid" {
invalidToken := "eyJhbGciOiJSUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiIxMjM0NTY3ODkwIiwibmFtZSI6IkpvaG4gRG9lIiwiaWF0IjoxNTE2MjM5MDIyfQ.invalid"
authMethods = []cryptossh.AuthMethod{
cryptossh.Password(invalidToken),
}
}
config := &cryptossh.ClientConfig{
User: testutil.GetTestUsername(t),
Auth: authMethods,
HostKeyCallback: cryptossh.InsecureIgnoreHostKey(),
Timeout: 2 * time.Second,
}
conn, err := cryptossh.Dial("tcp", net.JoinHostPort(host, portStr), config)
if tc.wantAuthOK {
require.NoError(t, err, "JWT authentication should succeed")
} else if err != nil {
t.Logf("Connection failed as expected: %v", err)
return
}
if conn != nil {
defer func() {
if err := conn.Close(); err != nil {
t.Logf("close connection: %v", err)
}
}()
}
err = tc.testOperation(t, conn, serverAddr)
if tc.wantOpSuccess {
require.NoError(t, err, "Operation should succeed")
} else {
assert.Error(t, err, "Operation should fail")
}
})
}
}

View File

@@ -2,18 +2,27 @@ package server
import (
"context"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"net"
"net/netip"
"strings"
"sync"
"time"
"github.com/gliderlabs/ssh"
gojwt "github.com/golang-jwt/jwt/v5"
log "github.com/sirupsen/logrus"
cryptossh "golang.org/x/crypto/ssh"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/management/server/auth/jwt"
nbcontext "github.com/netbirdio/netbird/management/server/context"
"github.com/netbirdio/netbird/version"
)
// DefaultSSHPort is the default SSH port of the NetBird's embedded SSH server
@@ -27,6 +36,9 @@ const (
errExitSession = "exit session error: %v"
msgPrivilegedUserDisabled = "privileged user login is disabled"
// DefaultJWTMaxTokenAge is the default maximum age for JWT tokens accepted by the SSH server
DefaultJWTMaxTokenAge = 5 * 60
)
var (
@@ -69,7 +81,6 @@ func (e *UserNotFoundError) Unwrap() error {
}
// safeLogCommand returns a safe representation of the command for logging
// Only logs the first argument to avoid leaking sensitive information
func safeLogCommand(cmd []string) string {
if len(cmd) == 0 {
return "<empty>"
@@ -80,17 +91,14 @@ func safeLogCommand(cmd []string) string {
return fmt.Sprintf("%s [%d args]", cmd[0], len(cmd)-1)
}
// sshConnectionState tracks the state of an SSH connection
type sshConnectionState struct {
hasActivePortForward bool
username string
remoteAddr string
}
// Server is the SSH server implementation
type Server struct {
sshServer *ssh.Server
authorizedKeys map[string]ssh.PublicKey
mu sync.RWMutex
hostKeyPEM []byte
sessions map[SessionKey]ssh.Session
@@ -100,30 +108,53 @@ type Server struct {
allowRemotePortForwarding bool
allowRootLogin bool
allowSFTP bool
jwtEnabled bool
netstackNet *netstack.Net
wgAddress wgaddr.Address
ifIdx int
remoteForwardListeners map[ForwardKey]net.Listener
sshConnections map[*cryptossh.ServerConn]*sshConnectionState
jwtValidator *jwt.Validator
jwtExtractor *jwt.ClaimsExtractor
jwtConfig *JWTConfig
}
// New creates an SSH server instance with the provided host key
func New(hostKeyPEM []byte) *Server {
return &Server{
type JWTConfig struct {
Issuer string
Audience string
KeysLocation string
MaxTokenAge int64
}
// Config contains all SSH server configuration options
type Config struct {
// JWT authentication configuration. If nil, JWT authentication is disabled
JWT *JWTConfig
// HostKey is the SSH server host key in PEM format
HostKeyPEM []byte
}
// New creates an SSH server instance with the provided host key and optional JWT configuration
// If jwtConfig is nil, JWT authentication is disabled
func New(config *Config) *Server {
s := &Server{
mu: sync.RWMutex{},
hostKeyPEM: hostKeyPEM,
authorizedKeys: make(map[string]ssh.PublicKey),
hostKeyPEM: config.HostKeyPEM,
sessions: make(map[SessionKey]ssh.Session),
remoteForwardListeners: make(map[ForwardKey]net.Listener),
sshConnections: make(map[*cryptossh.ServerConn]*sshConnectionState),
jwtEnabled: config.JWT != nil,
jwtConfig: config.JWT,
}
return s
}
// Start runs the SSH server, automatically detecting netstack vs standard networking
// Does all setup synchronously, then starts serving in a goroutine and returns immediately
// Start runs the SSH server
func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
s.mu.Lock()
defer s.mu.Unlock()
@@ -139,7 +170,7 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
sshServer, err := s.createSSHServer(ln.Addr())
if err != nil {
s.cleanupOnError(ln)
s.closeListener(ln)
return fmt.Errorf("create SSH server: %w", err)
}
@@ -154,7 +185,6 @@ func (s *Server) Start(ctx context.Context, addr netip.AddrPort) error {
return nil
}
// createListener creates a network listener based on netstack vs standard networking
func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.Listener, string, error) {
if s.netstackNet != nil {
ln, err := s.netstackNet.ListenTCPAddrPort(addr)
@@ -173,22 +203,15 @@ func (s *Server) createListener(ctx context.Context, addr netip.AddrPort) (net.L
return ln, addr.String(), nil
}
// closeListener safely closes a listener
func (s *Server) closeListener(ln net.Listener) {
if ln == nil {
return
}
if err := ln.Close(); err != nil {
log.Debugf("listener close error: %v", err)
}
}
// cleanupOnError cleans up resources when SSH server creation fails
func (s *Server) cleanupOnError(ln net.Listener) {
if s.ifIdx == 0 || ln == nil {
return
}
s.closeListener(ln)
}
// Stop closes the SSH server
func (s *Server) Stop() error {
s.mu.Lock()
@@ -207,28 +230,6 @@ func (s *Server) Stop() error {
return nil
}
// RemoveAuthorizedKey removes the SSH key for a peer
func (s *Server) RemoveAuthorizedKey(peer string) {
s.mu.Lock()
defer s.mu.Unlock()
delete(s.authorizedKeys, peer)
}
// AddAuthorizedKey adds an SSH key for a peer
func (s *Server) AddAuthorizedKey(peer, newKey string) error {
s.mu.Lock()
defer s.mu.Unlock()
parsedKey, _, _, _, err := ssh.ParseAuthorizedKey([]byte(newKey))
if err != nil {
return fmt.Errorf("parse key: %w", err)
}
s.authorizedKeys[peer] = parsedKey
return nil
}
// SetNetstackNet sets the netstack network for userspace networking
func (s *Server) SetNetstackNet(net *netstack.Net) {
s.mu.Lock()
@@ -243,34 +244,195 @@ func (s *Server) SetNetworkValidation(addr wgaddr.Address) {
s.wgAddress = addr
}
// SetSocketFilter configures eBPF socket filtering for the SSH server
func (s *Server) SetSocketFilter(ifIdx int) {
// ensureJWTValidator initializes the JWT validator and extractor if not already initialized
func (s *Server) ensureJWTValidator() error {
s.mu.RLock()
if s.jwtValidator != nil && s.jwtExtractor != nil {
s.mu.RUnlock()
return nil
}
config := s.jwtConfig
s.mu.RUnlock()
if config == nil {
return fmt.Errorf("JWT config not set")
}
log.Debugf("Initializing JWT validator (issuer: %s, audience: %s)", config.Issuer, config.Audience)
validator := jwt.NewValidator(
config.Issuer,
[]string{config.Audience},
config.KeysLocation,
true,
)
extractor := jwt.NewClaimsExtractor(
jwt.WithAudience(config.Audience),
)
s.mu.Lock()
defer s.mu.Unlock()
s.ifIdx = ifIdx
if s.jwtValidator != nil && s.jwtExtractor != nil {
return nil
}
s.jwtValidator = validator
s.jwtExtractor = extractor
log.Infof("JWT validator initialized successfully")
return nil
}
func (s *Server) publicKeyHandler(ctx ssh.Context, key ssh.PublicKey) bool {
func (s *Server) validateJWTToken(tokenString string) (*gojwt.Token, error) {
s.mu.RLock()
defer s.mu.RUnlock()
jwtValidator := s.jwtValidator
jwtConfig := s.jwtConfig
s.mu.RUnlock()
for _, allowed := range s.authorizedKeys {
if ssh.KeysEqual(allowed, key) {
if ctx != nil {
log.Debugf("SSH key authentication successful for user %s from %s", ctx.User(), ctx.RemoteAddr())
if jwtValidator == nil {
return nil, fmt.Errorf("JWT validator not initialized")
}
token, err := jwtValidator.ValidateAndParse(context.Background(), tokenString)
if err != nil {
if jwtConfig != nil {
if claims, parseErr := s.parseTokenWithoutValidation(tokenString); parseErr == nil {
return nil, fmt.Errorf("validate token (expected issuer=%s, audience=%s, actual issuer=%v, audience=%v): %w",
jwtConfig.Issuer, jwtConfig.Audience, claims["iss"], claims["aud"], err)
}
return true
}
return nil, fmt.Errorf("validate token: %w", err)
}
if ctx != nil {
log.Warnf("SSH key authentication failed for user %s from %s: key not authorized (type: %s, fingerprint: %s)",
ctx.User(), ctx.RemoteAddr(), key.Type(), cryptossh.FingerprintSHA256(key))
if err := s.checkTokenAge(token, jwtConfig); err != nil {
return nil, err
}
return false
return token, nil
}
func (s *Server) checkTokenAge(token *gojwt.Token, jwtConfig *JWTConfig) error {
if jwtConfig == nil || jwtConfig.MaxTokenAge <= 0 {
return nil
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token has invalid claims format (user=%s)", userID)
}
iat, ok := claims["iat"].(float64)
if !ok {
userID := extractUserID(token)
return fmt.Errorf("token missing iat claim (user=%s)", userID)
}
issuedAt := time.Unix(int64(iat), 0)
tokenAge := time.Since(issuedAt)
maxAge := time.Duration(jwtConfig.MaxTokenAge) * time.Second
if tokenAge > maxAge {
userID := getUserIDFromClaims(claims)
return fmt.Errorf("token expired for user=%s: age=%v, max=%v", userID, tokenAge, maxAge)
}
return nil
}
func (s *Server) extractAndValidateUser(token *gojwt.Token) (*nbcontext.UserAuth, error) {
s.mu.RLock()
jwtExtractor := s.jwtExtractor
s.mu.RUnlock()
if jwtExtractor == nil {
userID := extractUserID(token)
return nil, fmt.Errorf("JWT extractor not initialized (user=%s)", userID)
}
userAuth, err := jwtExtractor.ToUserAuth(token)
if err != nil {
userID := extractUserID(token)
return nil, fmt.Errorf("extract user from token (user=%s): %w", userID, err)
}
if !s.hasSSHAccess(&userAuth) {
return nil, fmt.Errorf("user %s does not have SSH access permissions", userAuth.UserId)
}
return &userAuth, nil
}
func (s *Server) hasSSHAccess(userAuth *nbcontext.UserAuth) bool {
return userAuth.UserId != ""
}
func extractUserID(token *gojwt.Token) string {
if token == nil {
return "unknown"
}
claims, ok := token.Claims.(gojwt.MapClaims)
if !ok {
return "unknown"
}
return getUserIDFromClaims(claims)
}
func getUserIDFromClaims(claims gojwt.MapClaims) string {
if sub, ok := claims["sub"].(string); ok && sub != "" {
return sub
}
if userID, ok := claims["user_id"].(string); ok && userID != "" {
return userID
}
if email, ok := claims["email"].(string); ok && email != "" {
return email
}
return "unknown"
}
func (s *Server) parseTokenWithoutValidation(tokenString string) (map[string]interface{}, error) {
parts := strings.Split(tokenString, ".")
if len(parts) != 3 {
return nil, fmt.Errorf("invalid token format")
}
payload, err := base64.RawURLEncoding.DecodeString(parts[1])
if err != nil {
return nil, fmt.Errorf("decode payload: %w", err)
}
var claims map[string]interface{}
if err := json.Unmarshal(payload, &claims); err != nil {
return nil, fmt.Errorf("parse claims: %w", err)
}
return claims, nil
}
func (s *Server) passwordHandler(ctx ssh.Context, password string) bool {
if err := s.ensureJWTValidator(); err != nil {
log.Errorf("JWT validator initialization failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
token, err := s.validateJWTToken(password)
if err != nil {
log.Warnf("JWT authentication failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
userAuth, err := s.extractAndValidateUser(token)
if err != nil {
log.Warnf("User validation failed for user %s from %s: %v", ctx.User(), ctx.RemoteAddr(), err)
return false
}
log.Infof("JWT authentication successful for user %s (JWT user ID: %s) from %s", ctx.User(), userAuth.UserId, ctx.RemoteAddr())
return true
}
// markConnectionActivePortForward marks an SSH connection as having an active port forward
func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn, username, remoteAddr string) {
s.mu.Lock()
defer s.mu.Unlock()
@@ -286,14 +448,12 @@ func (s *Server) markConnectionActivePortForward(sshConn *cryptossh.ServerConn,
}
}
// connectionCloseHandler cleans up connection state when SSH connections fail/close
func (s *Server) connectionCloseHandler(conn net.Conn, err error) {
// We can't extract the SSH connection from net.Conn directly
// Connection cleanup will happen during session cleanup or via timeout
log.Debugf("SSH connection failed for %s: %v", conn.RemoteAddr(), err)
}
// findSessionKeyByContext finds the session key by matching SSH connection context
func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
if ctx == nil {
return "unknown"
@@ -319,14 +479,13 @@ func (s *Server) findSessionKeyByContext(ctx ssh.Context) SessionKey {
// Return a temporary key that we'll fix up later
if ctx.User() != "" && ctx.RemoteAddr() != nil {
tempKey := SessionKey(fmt.Sprintf("%s@%s", ctx.User(), ctx.RemoteAddr().String()))
log.Debugf("using temporary session key for port forward tracking: %s", tempKey)
log.Debugf("Using temporary session key for early port forward tracking: %s (will be updated when session established)", tempKey)
return tempKey
}
return "unknown"
}
// connectionValidator validates incoming connections based on source IP
func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
s.mu.RLock()
netbirdNetwork := s.wgAddress.Network
@@ -340,8 +499,8 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
remoteAddr := conn.RemoteAddr()
tcpAddr, ok := remoteAddr.(*net.TCPAddr)
if !ok {
log.Debugf("SSH connection from non-TCP address %s allowed", remoteAddr)
return conn
log.Warnf("SSH connection rejected: non-TCP address %s", remoteAddr)
return nil
}
remoteIP, ok := netip.AddrFromSlice(tcpAddr.IP)
@@ -357,15 +516,14 @@ func (s *Server) connectionValidator(_ ssh.Context, conn net.Conn) net.Conn {
}
if !netbirdNetwork.Contains(remoteIP) {
log.Warnf("SSH connection rejected from non-NetBird IP %s (allowed range: %s)", remoteIP, netbirdNetwork)
log.Warnf("SSH connection rejected from non-NetBird IP %s", remoteIP)
return nil
}
log.Debugf("SSH connection from %s allowed", remoteIP)
log.Infof("SSH connection from NetBird peer %s allowed", remoteIP)
return conn
}
// isShutdownError checks if the error is expected during normal shutdown
func isShutdownError(err error) bool {
if errors.Is(err, net.ErrClosed) {
return true
@@ -379,12 +537,16 @@ func isShutdownError(err error) bool {
return false
}
// createSSHServer creates and configures the SSH server
func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
if err := enableUserSwitching(); err != nil {
log.Warnf("failed to enable user switching: %v", err)
}
serverVersion := fmt.Sprintf("%s-%s", detection.ServerIdentifier, version.NetbirdVersion())
if s.jwtEnabled {
serverVersion += " " + detection.JWTRequiredMarker
}
server := &ssh.Server{
Addr: addr.String(),
Handler: s.sessionHandler,
@@ -402,6 +564,11 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
},
ConnCallback: s.connectionValidator,
ConnectionFailedCallback: s.connectionCloseHandler,
Version: serverVersion,
}
if s.jwtEnabled {
server.PasswordHandler = s.passwordHandler
}
hostKeyPEM := ssh.HostKeyPEM(s.hostKeyPEM)
@@ -413,14 +580,12 @@ func (s *Server) createSSHServer(addr net.Addr) (*ssh.Server, error) {
return server, nil
}
// storeRemoteForwardListener stores a remote forward listener for cleanup
func (s *Server) storeRemoteForwardListener(key ForwardKey, ln net.Listener) {
s.mu.Lock()
defer s.mu.Unlock()
s.remoteForwardListeners[key] = ln
}
// removeRemoteForwardListener removes and closes a remote forward listener
func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
s.mu.Lock()
defer s.mu.Unlock()
@@ -438,7 +603,6 @@ func (s *Server) removeRemoteForwardListener(key ForwardKey) bool {
return true
}
// directTCPIPHandler handles direct-tcpip channel requests for local port forwarding with privilege validation
func (s *Server) directTCPIPHandler(srv *ssh.Server, conn *cryptossh.ServerConn, newChan cryptossh.NewChannel, ctx ssh.Context) {
var payload struct {
Host string

View File

@@ -22,12 +22,6 @@ func TestServer_RootLoginRestriction(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
tests := []struct {
name string
allowRoot bool
@@ -117,10 +111,12 @@ func TestServer_RootLoginRestriction(t *testing.T) {
defer cleanup()
// Create server with specific configuration
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(tt.allowRoot)
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Test the userNameLookup method directly
user, err := server.userNameLookup(tt.username)
@@ -196,7 +192,11 @@ func TestServer_PortForwardingRestriction(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
// Create server with specific configuration
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowLocalPortForwarding(tt.allowLocalForwarding)
server.SetAllowRemotePortForwarding(tt.allowRemoteForwarding)
@@ -234,17 +234,13 @@ func TestServer_PortConflictHandling(t *testing.T) {
hostKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server
server := New(hostKey)
server.SetAllowRootLogin(true) // Allow root login for testing
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowRootLogin(true)
serverAddr := StartTestServer(t, server)
defer func() {
@@ -263,7 +259,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx1, cancel1 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel1()
client1, err := sshclient.DialInsecure(ctx1, serverAddr, currentUser.Username)
client1, err := sshclient.Dial(ctx1, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client1.Close()
@@ -274,7 +272,9 @@ func TestServer_PortConflictHandling(t *testing.T) {
ctx2, cancel2 := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel2()
client2, err := sshclient.DialInsecure(ctx2, serverAddr, currentUser.Username)
client2, err := sshclient.Dial(ctx2, serverAddr, currentUser.Username, sshclient.DialOptions{
InsecureSkipVerify: true,
})
require.NoError(t, err)
defer func() {
err := client2.Close()

View File

@@ -7,11 +7,9 @@ import (
"net/netip"
"os/user"
"runtime"
"strings"
"testing"
"time"
"github.com/gliderlabs/ssh"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
cryptossh "golang.org/x/crypto/ssh"
@@ -19,82 +17,15 @@ import (
nbssh "github.com/netbirdio/netbird/client/ssh"
)
func TestServer_AddAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
keys := map[string][]byte{}
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys[peer] = remotePubKey
}
for peer, remotePubKey := range keys {
k, ok := server.authorizedKeys[peer]
assert.True(t, ok, "expecting remotePeer key to be found in authorizedKeys")
assert.Equal(t, string(remotePubKey), strings.TrimSpace(string(cryptossh.MarshalAuthorizedKey(k))))
}
}
func TestServer_RemoveAuthorizedKey(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
err = server.AddAuthorizedKey("remotePeer", string(remotePubKey))
require.NoError(t, err)
server.RemoveAuthorizedKey("remotePeer")
_, ok := server.authorizedKeys["remotePeer"]
assert.False(t, ok, "expecting remotePeer's SSH key to be removed")
}
func TestServer_PubKeyHandler(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
var keys []ssh.PublicKey
for i := 0; i < 10; i++ {
peer := fmt.Sprintf("%s-%d", "remotePeer", i)
remotePrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
remotePubKey, err := nbssh.GeneratePublicKey(remotePrivKey)
require.NoError(t, err)
remoteParsedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(remotePubKey)
require.NoError(t, err)
err = server.AddAuthorizedKey(peer, string(remotePubKey))
require.NoError(t, err)
keys = append(keys, remoteParsedPubKey)
}
for _, key := range keys {
accepted := server.publicKeyHandler(nil, key)
assert.True(t, accepted, "SSH key should be accepted")
}
}
func TestServer_StartStop(t *testing.T) {
key, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(key)
serverConfig := &Config{
HostKeyPEM: key,
JWT: nil,
}
server := New(serverConfig)
err = server.Stop()
assert.NoError(t, err)
@@ -108,15 +39,13 @@ func TestSSHServerIntegration(t *testing.T) {
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with random port
server := New(hostKey)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server in background
serverAddr := "127.0.0.1:0"
@@ -212,13 +141,13 @@ func TestSSHServerMultipleConnections(t *testing.T) {
// Generate client key pair
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientPubKey, err := nbssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server
server := New(hostKey)
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
@@ -324,20 +253,12 @@ func TestSSHServerNoAuthMode(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Generate authorized key
authorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
authorizedPubKey, err := nbssh.GeneratePublicKey(authorizedPrivKey)
require.NoError(t, err)
// Generate unauthorized key (different from authorized)
unauthorizedPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
// Create server with only one authorized key
server := New(hostKey)
err = server.AddAuthorizedKey("authorized-peer", string(authorizedPubKey))
require.NoError(t, err)
// Create server
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
// Start server
serverAddr := "127.0.0.1:0"
@@ -377,8 +298,10 @@ func TestSSHServerNoAuthMode(t *testing.T) {
require.NoError(t, err)
}()
// Parse unauthorized private key
unauthorizedSigner, err := cryptossh.ParsePrivateKey(unauthorizedPrivKey)
// Generate a client private key for SSH protocol (server doesn't check it)
clientPrivKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
clientSigner, err := cryptossh.ParsePrivateKey(clientPrivKey)
require.NoError(t, err)
// Parse server host key
@@ -390,17 +313,17 @@ func TestSSHServerNoAuthMode(t *testing.T) {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user for test")
// Try to connect with unauthorized key
// Try to connect with client key
config := &cryptossh.ClientConfig{
User: currentUser.Username,
Auth: []cryptossh.AuthMethod{
cryptossh.PublicKeys(unauthorizedSigner),
cryptossh.PublicKeys(clientSigner),
},
HostKeyCallback: cryptossh.FixedHostKey(hostPubKey),
Timeout: 3 * time.Second,
}
// This should succeed in no-auth mode
// This should succeed in no-auth mode (server doesn't verify keys)
conn, err := cryptossh.Dial("tcp", serverAddr, config)
assert.NoError(t, err, "Connection should succeed in no-auth mode")
if conn != nil {
@@ -412,7 +335,11 @@ func TestSSHServerStartStopCycle(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
serverAddr := "127.0.0.1:0"
// Test multiple start/stop cycles
@@ -485,8 +412,17 @@ func TestSSHServer_PortForwardingConfiguration(t *testing.T) {
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
require.NoError(t, err)
server1 := New(hostKey)
server2 := New(hostKey)
serverConfig1 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server1 := New(serverConfig1)
serverConfig2 := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server2 := New(serverConfig2)
assert.False(t, server1.allowLocalPortForwarding, "Local port forwarding should be disabled by default for security")
assert.False(t, server1.allowRemotePortForwarding, "Remote port forwarding should be disabled by default for security")

View File

@@ -35,17 +35,15 @@ func TestSSHServer_SFTPSubsystem(t *testing.T) {
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP enabled
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(true)
server.SetAllowRootLogin(true) // Allow root login for testing
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
server.SetAllowRootLogin(true)
// Start server
serverAddr := "127.0.0.1:0"
@@ -144,17 +142,15 @@ func TestSSHServer_SFTPDisabled(t *testing.T) {
// Generate client key pair
clientPrivKey, err := ssh.GeneratePrivateKey(ssh.ED25519)
require.NoError(t, err)
clientPubKey, err := ssh.GeneratePublicKey(clientPrivKey)
require.NoError(t, err)
// Create server with SFTP disabled
server := New(hostKey)
serverConfig := &Config{
HostKeyPEM: hostKey,
JWT: nil,
}
server := New(serverConfig)
server.SetAllowSFTP(false)
// Add client's public key as authorized
err = server.AddAuthorizedKey("test-peer", string(clientPubKey))
require.NoError(t, err)
// Start server
serverAddr := "127.0.0.1:0"
started := make(chan string, 1)

View File

@@ -14,7 +14,6 @@ func StartTestServer(t *testing.T, server *Server) string {
errChan := make(chan error, 1)
go func() {
// Get a free port
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
errChan <- err
@@ -26,9 +25,12 @@ func StartTestServer(t *testing.T, server *Server) string {
return
}
started <- actualAddr
addrPort := netip.MustParseAddrPort(actualAddr)
errChan <- server.Start(context.Background(), addrPort)
if err := server.Start(context.Background(), addrPort); err != nil {
errChan <- err
return
}
started <- actualAddr
}()
select {

View File

@@ -0,0 +1,172 @@
package testutil
import (
"fmt"
"log"
"os"
"os/exec"
"os/user"
"runtime"
"strings"
"testing"
"github.com/stretchr/testify/require"
)
var testCreatedUsers = make(map[string]bool)
var testUsersToCleanup []string
// GetTestUsername returns an appropriate username for testing
func GetTestUsername(t *testing.T) string {
if runtime.GOOS == "windows" {
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
if IsSystemAccount(currentUser.Username) {
if IsCI() {
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
} else {
if _, err := user.Lookup("Administrator"); err == nil {
return "Administrator"
}
if testUser := GetOrCreateTestUser(t); testUser != "" {
return testUser
}
}
}
return currentUser.Username
}
currentUser, err := user.Current()
require.NoError(t, err, "Should be able to get current user")
return currentUser.Username
}
// IsCI checks if we're running in a CI environment
func IsCI() bool {
if os.Getenv("GITHUB_ACTIONS") == "true" || os.Getenv("CI") == "true" {
return true
}
hostname, err := os.Hostname()
if err == nil && strings.HasPrefix(hostname, "runner") {
return true
}
return false
}
// IsSystemAccount checks if the user is a system account that can't authenticate
func IsSystemAccount(username string) bool {
systemAccounts := []string{
"system",
"NT AUTHORITY\\SYSTEM",
"NT AUTHORITY\\LOCAL SERVICE",
"NT AUTHORITY\\NETWORK SERVICE",
}
for _, sysAccount := range systemAccounts {
if strings.EqualFold(username, sysAccount) {
return true
}
}
return false
}
// RegisterTestUserCleanup registers a test user for cleanup
func RegisterTestUserCleanup(username string) {
if !testCreatedUsers[username] {
testCreatedUsers[username] = true
testUsersToCleanup = append(testUsersToCleanup, username)
}
}
// CleanupTestUsers removes all created test users
func CleanupTestUsers() {
for _, username := range testUsersToCleanup {
RemoveWindowsTestUser(username)
}
testUsersToCleanup = nil
testCreatedUsers = make(map[string]bool)
}
// GetOrCreateTestUser creates a test user on Windows if needed
func GetOrCreateTestUser(t *testing.T) string {
testUsername := "netbird-test-user"
if _, err := user.Lookup(testUsername); err == nil {
return testUsername
}
if CreateWindowsTestUser(t, testUsername) {
RegisterTestUserCleanup(testUsername)
return testUsername
}
return ""
}
// RemoveWindowsTestUser removes a local user on Windows using PowerShell
func RemoveWindowsTestUser(username string) {
if runtime.GOOS != "windows" {
return
}
psCmd := fmt.Sprintf(`
try {
Remove-LocalUser -Name "%s" -ErrorAction Stop
Write-Output "User removed successfully"
} catch {
if ($_.Exception.Message -like "*cannot be found*") {
Write-Output "User not found (already removed)"
} else {
Write-Error $_.Exception.Message
}
}
`, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
log.Printf("Failed to remove test user %s: %v, output: %s", username, err, string(output))
} else {
log.Printf("Test user %s cleanup result: %s", username, string(output))
}
}
// CreateWindowsTestUser creates a local user on Windows using PowerShell
func CreateWindowsTestUser(t *testing.T, username string) bool {
if runtime.GOOS != "windows" {
return false
}
psCmd := fmt.Sprintf(`
try {
$password = ConvertTo-SecureString "TestPassword123!" -AsPlainText -Force
New-LocalUser -Name "%s" -Password $password -Description "NetBird test user" -UserMayNotChangePassword -PasswordNeverExpires
Add-LocalGroupMember -Group "Users" -Member "%s"
Write-Output "User created successfully"
} catch {
if ($_.Exception.Message -like "*already exists*") {
Write-Output "User already exists"
} else {
Write-Error $_.Exception.Message
exit 1
}
}
`, username, username)
cmd := exec.Command("powershell", "-Command", psCmd)
output, err := cmd.CombinedOutput()
if err != nil {
t.Logf("Failed to create test user: %v, output: %s", err, string(output))
return false
}
t.Logf("Test user creation result: %s", string(output))
return true
}

View File

@@ -77,6 +77,7 @@ type Info struct {
EnableSSHSFTP bool
EnableSSHLocalPortForwarding bool
EnableSSHRemotePortForwarding bool
DisableSSHAuth bool
}
func (i *Info) SetFlags(
@@ -85,6 +86,7 @@ func (i *Info) SetFlags(
disableClientRoutes, disableServerRoutes,
disableDNS, disableFirewall, blockLANAccess, blockInbound, lazyConnectionEnabled bool,
enableSSHRoot, enableSSHSFTP, enableSSHLocalPortForwarding, enableSSHRemotePortForwarding *bool,
disableSSHAuth *bool,
) {
i.RosenpassEnabled = rosenpassEnabled
i.RosenpassPermissive = rosenpassPermissive
@@ -113,6 +115,9 @@ func (i *Info) SetFlags(
if enableSSHRemotePortForwarding != nil {
i.EnableSSHRemotePortForwarding = *enableSSHRemotePortForwarding
}
if disableSSHAuth != nil {
i.DisableSSHAuth = *disableSSHAuth
}
}
// extractUserAgent extracts Netbird's agent (client) name and version from the outgoing context

View File

@@ -270,6 +270,7 @@ type serviceClient struct {
sEnableSSHSFTP *widget.Check
sEnableSSHLocalPortForward *widget.Check
sEnableSSHRemotePortForward *widget.Check
sDisableSSHAuth *widget.Check
// observable settings over corresponding iMngURL and iPreSharedKey values.
managementURL string
@@ -288,6 +289,7 @@ type serviceClient struct {
enableSSHSFTP bool
enableSSHLocalPortForward bool
enableSSHRemotePortForward bool
disableSSHAuth bool
connected bool
update *version.Update
@@ -437,6 +439,7 @@ func (s *serviceClient) showSettingsUI() {
s.sEnableSSHSFTP = widget.NewCheck("Enable SSH SFTP", nil)
s.sEnableSSHLocalPortForward = widget.NewCheck("Enable SSH Local Port Forwarding", nil)
s.sEnableSSHRemotePortForward = widget.NewCheck("Enable SSH Remote Port Forwarding", nil)
s.sDisableSSHAuth = widget.NewCheck("Disable SSH Authentication", nil)
s.wSettings.SetContent(s.getSettingsForm())
s.wSettings.Resize(fyne.NewSize(600, 400))
@@ -597,6 +600,7 @@ func (s *serviceClient) buildSetConfigRequest(iMngURL string, port, mtu int64) (
req.EnableSSHSFTP = &s.sEnableSSHSFTP.Checked
req.EnableSSHLocalPortForward = &s.sEnableSSHLocalPortForward.Checked
req.EnableSSHRemotePortForward = &s.sEnableSSHRemotePortForward.Checked
req.DisableSSHAuth = &s.sDisableSSHAuth.Checked
if s.iPreSharedKey.Text != censoredPreSharedKey {
req.OptionalPreSharedKey = &s.iPreSharedKey.Text
@@ -682,6 +686,7 @@ func (s *serviceClient) getSSHForm() *widget.Form {
{Text: "Enable SSH SFTP", Widget: s.sEnableSSHSFTP},
{Text: "Enable SSH Local Port Forwarding", Widget: s.sEnableSSHLocalPortForward},
{Text: "Enable SSH Remote Port Forwarding", Widget: s.sEnableSSHRemotePortForward},
{Text: "Disable SSH Authentication", Widget: s.sDisableSSHAuth},
},
}
}
@@ -690,7 +695,8 @@ func (s *serviceClient) hasSSHChanges() bool {
return s.enableSSHRoot != s.sEnableSSHRoot.Checked ||
s.enableSSHSFTP != s.sEnableSSHSFTP.Checked ||
s.enableSSHLocalPortForward != s.sEnableSSHLocalPortForward.Checked ||
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked
s.enableSSHRemotePortForward != s.sEnableSSHRemotePortForward.Checked ||
s.disableSSHAuth != s.sDisableSSHAuth.Checked
}
func (s *serviceClient) login(openURL bool) (*proto.LoginResponse, error) {
@@ -1233,6 +1239,9 @@ func (s *serviceClient) getSrvConfig() {
if cfg.EnableSSHRemotePortForwarding != nil {
s.enableSSHRemotePortForward = *cfg.EnableSSHRemotePortForwarding
}
if cfg.DisableSSHAuth != nil {
s.disableSSHAuth = *cfg.DisableSSHAuth
}
if s.showAdvancedSettings {
s.iMngURL.SetText(s.managementURL)
@@ -1266,6 +1275,9 @@ func (s *serviceClient) getSrvConfig() {
if cfg.EnableSSHRemotePortForwarding != nil {
s.sEnableSSHRemotePortForward.SetChecked(*cfg.EnableSSHRemotePortForwarding)
}
if cfg.DisableSSHAuth != nil {
s.sDisableSSHAuth.SetChecked(*cfg.DisableSSHAuth)
}
}
if s.mNotifications == nil {
@@ -1348,6 +1360,9 @@ func protoConfigToConfig(cfg *proto.GetConfigResponse) *profilemanager.Config {
if cfg.EnableSSHRemotePortForwarding {
config.EnableSSHRemotePortForwarding = &cfg.EnableSSHRemotePortForwarding
}
if cfg.DisableSSHAuth {
config.DisableSSHAuth = &cfg.DisableSSHAuth
}
return &config
}

View File

@@ -245,6 +245,6 @@ func (h *eventHandler) logout(ctx context.Context) error {
}
h.client.getSrvConfig()
return nil
}

View File

@@ -11,6 +11,7 @@ import (
log "github.com/sirupsen/logrus"
netbird "github.com/netbirdio/netbird/client/embed"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
"github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
@@ -125,10 +126,15 @@ func createSSHMethod(client *netbird.Client) js.Func {
username = args[2].String()
}
var jwtToken string
if len(args) > 3 && !args[3].IsNull() && !args[3].IsUndefined() {
jwtToken = args[3].String()
}
return createPromise(func(resolve, reject js.Value) {
sshClient := ssh.NewClient(client)
if err := sshClient.Connect(host, port, username); err != nil {
if err := sshClient.Connect(host, port, username, jwtToken); err != nil {
reject.Invoke(err.Error())
return
}
@@ -191,12 +197,43 @@ func createPromise(handler func(resolve, reject js.Value)) js.Value {
}))
}
// createDetectSSHServerMethod creates the SSH server detection method
func createDetectSSHServerMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(this js.Value, args []js.Value) any {
if len(args) < 2 {
return js.ValueOf("error: requires host and port")
}
host := args[0].String()
port := args[1].Int()
return createPromise(func(resolve, reject js.Value) {
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
serverType, err := detectSSHServerType(ctx, client, host, port)
if err != nil {
reject.Invoke(err.Error())
return
}
resolve.Invoke(js.ValueOf(serverType.RequiresJWT()))
})
})
}
// detectSSHServerType detects SSH server type using NetBird network connection
func detectSSHServerType(ctx context.Context, client *netbird.Client, host string, port int) (sshdetection.ServerType, error) {
return sshdetection.DetectSSHServerType(ctx, client, host, port)
}
// createClientObject wraps the NetBird client in a JavaScript object
func createClientObject(client *netbird.Client) js.Value {
obj := make(map[string]interface{})
obj["start"] = createStartMethod(client)
obj["stop"] = createStopMethod(client)
obj["detectSSHServerType"] = createDetectSSHServerMethod(client)
obj["createSSHConnection"] = createSSHMethod(client)
obj["proxyRequest"] = createProxyRequestMethod(client)
obj["createRDPProxy"] = createRDPProxyMethod(client)

View File

@@ -13,6 +13,7 @@ import (
"golang.org/x/crypto/ssh"
netbird "github.com/netbirdio/netbird/client/embed"
nbssh "github.com/netbirdio/netbird/client/ssh"
)
const (
@@ -45,34 +46,19 @@ func NewClient(nbClient *netbird.Client) *Client {
}
// Connect establishes an SSH connection through NetBird network
func (c *Client) Connect(host string, port int, username string) error {
func (c *Client) Connect(host string, port int, username, jwtToken string) error {
addr := fmt.Sprintf("%s:%d", host, port)
logrus.Infof("SSH: Connecting to %s as %s", addr, username)
var authMethods []ssh.AuthMethod
nbConfig, err := c.nbClient.GetConfig()
authMethods, err := c.getAuthMethods(jwtToken)
if err != nil {
return fmt.Errorf("get NetBird config: %w", err)
return err
}
if nbConfig.SSHKey == "" {
return fmt.Errorf("no NetBird SSH key available - key should be generated during client initialization")
}
signer, err := parseSSHPrivateKey([]byte(nbConfig.SSHKey))
if err != nil {
return fmt.Errorf("parse NetBird SSH private key: %w", err)
}
pubKey := signer.PublicKey()
logrus.Infof("SSH: Using NetBird key authentication with public key type: %s", pubKey.Type())
authMethods = append(authMethods, ssh.PublicKeys(signer))
config := &ssh.ClientConfig{
User: username,
Auth: authMethods,
HostKeyCallback: ssh.InsecureIgnoreHostKey(),
HostKeyCallback: nbssh.CreateHostKeyCallback(c.nbClient),
Timeout: sshDialTimeout,
}
@@ -96,6 +82,33 @@ func (c *Client) Connect(host string, port int, username string) error {
return nil
}
// getAuthMethods returns SSH authentication methods, preferring JWT if available
func (c *Client) getAuthMethods(jwtToken string) ([]ssh.AuthMethod, error) {
if jwtToken != "" {
logrus.Debugf("SSH: Using JWT password authentication")
return []ssh.AuthMethod{ssh.Password(jwtToken)}, nil
}
logrus.Debugf("SSH: No JWT token, using public key authentication")
nbConfig, err := c.nbClient.GetConfig()
if err != nil {
return nil, fmt.Errorf("get NetBird config: %w", err)
}
if nbConfig.SSHKey == "" {
return nil, fmt.Errorf("no NetBird SSH key available")
}
signer, err := ssh.ParsePrivateKey([]byte(nbConfig.SSHKey))
if err != nil {
return nil, fmt.Errorf("parse NetBird SSH private key: %w", err)
}
logrus.Debugf("SSH: Added public key auth")
return []ssh.AuthMethod{ssh.PublicKeys(signer)}, nil
}
// StartSession starts an SSH session with PTY
func (c *Client) StartSession(cols, rows int) error {
if c.sshClient == nil {

View File

@@ -1,50 +0,0 @@
//go:build js
package ssh
import (
"crypto/x509"
"encoding/pem"
"fmt"
"strings"
"github.com/sirupsen/logrus"
"golang.org/x/crypto/ssh"
)
// parseSSHPrivateKey parses a private key in either SSH or PKCS8 format
func parseSSHPrivateKey(keyPEM []byte) (ssh.Signer, error) {
keyStr := string(keyPEM)
if !strings.Contains(keyStr, "-----BEGIN") {
keyPEM = []byte("-----BEGIN PRIVATE KEY-----\n" + keyStr + "\n-----END PRIVATE KEY-----")
}
signer, err := ssh.ParsePrivateKey(keyPEM)
if err == nil {
return signer, nil
}
logrus.Debugf("SSH: Failed to parse as SSH format: %v", err)
block, _ := pem.Decode(keyPEM)
if block == nil {
keyPreview := string(keyPEM)
if len(keyPreview) > 100 {
keyPreview = keyPreview[:100]
}
return nil, fmt.Errorf("decode PEM block from key: %s", keyPreview)
}
key, err := x509.ParsePKCS8PrivateKey(block.Bytes)
if err != nil {
logrus.Debugf("SSH: Failed to parse as PKCS8: %v", err)
if rsaKey, err := x509.ParsePKCS1PrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(rsaKey)
}
if ecKey, err := x509.ParseECPrivateKey(block.Bytes); err == nil {
return ssh.NewSignerFromKey(ecKey)
}
return nil, fmt.Errorf("parse private key: %w", err)
}
return ssh.NewSignerFromKey(key)
}