mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-20 01:06:45 +00:00
Add ssh authenatication with jwt (#4550)
This commit is contained in:
@@ -1,7 +1,6 @@
|
||||
package config
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
@@ -12,50 +11,41 @@ import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
const (
|
||||
// EnvDisableSSHConfig is the environment variable to disable SSH config management
|
||||
EnvDisableSSHConfig = "NB_DISABLE_SSH_CONFIG"
|
||||
|
||||
// EnvForceSSHConfig is the environment variable to force SSH config generation even with many peers
|
||||
EnvForceSSHConfig = "NB_FORCE_SSH_CONFIG"
|
||||
|
||||
// MaxPeersForSSHConfig is the default maximum number of peers before SSH config generation is disabled
|
||||
MaxPeersForSSHConfig = 200
|
||||
|
||||
// fileWriteTimeout is the timeout for file write operations
|
||||
fileWriteTimeout = 2 * time.Second
|
||||
)
|
||||
|
||||
// isSSHConfigDisabled checks if SSH config management is disabled via environment variable
|
||||
func isSSHConfigDisabled() bool {
|
||||
value := os.Getenv(EnvDisableSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse as boolean, default to true if non-empty but invalid
|
||||
disabled, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
// If not a valid boolean, treat any non-empty value as true
|
||||
return true
|
||||
}
|
||||
return disabled
|
||||
}
|
||||
|
||||
// isSSHConfigForced checks if SSH config generation is forced via environment variable
|
||||
func isSSHConfigForced() bool {
|
||||
value := os.Getenv(EnvForceSSHConfig)
|
||||
if value == "" {
|
||||
return false
|
||||
}
|
||||
|
||||
// Parse as boolean, default to true if non-empty but invalid
|
||||
forced, err := strconv.ParseBool(value)
|
||||
if err != nil {
|
||||
// If not a valid boolean, treat any non-empty value as true
|
||||
return true
|
||||
}
|
||||
return forced
|
||||
@@ -92,85 +82,55 @@ func writeFileWithTimeout(filename string, data []byte, perm os.FileMode) error
|
||||
}
|
||||
}
|
||||
|
||||
// writeFileOperationWithTimeout performs a file operation with timeout
|
||||
func writeFileOperationWithTimeout(filename string, operation func() error) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), fileWriteTimeout)
|
||||
defer cancel()
|
||||
|
||||
done := make(chan error, 1)
|
||||
go func() {
|
||||
done <- operation()
|
||||
}()
|
||||
|
||||
select {
|
||||
case err := <-done:
|
||||
return err
|
||||
case <-ctx.Done():
|
||||
return fmt.Errorf("file write timeout after %v: %s", fileWriteTimeout, filename)
|
||||
}
|
||||
}
|
||||
|
||||
// Manager handles SSH client configuration for NetBird peers
|
||||
type Manager struct {
|
||||
sshConfigDir string
|
||||
sshConfigFile string
|
||||
knownHostsDir string
|
||||
knownHostsFile string
|
||||
userKnownHosts string
|
||||
sshConfigDir string
|
||||
sshConfigFile string
|
||||
}
|
||||
|
||||
// PeerHostKey represents a peer's SSH host key information
|
||||
type PeerHostKey struct {
|
||||
// PeerSSHInfo represents a peer's SSH configuration information
|
||||
type PeerSSHInfo struct {
|
||||
Hostname string
|
||||
IP string
|
||||
FQDN string
|
||||
HostKey ssh.PublicKey
|
||||
}
|
||||
|
||||
// NewManager creates a new SSH config manager
|
||||
func NewManager() *Manager {
|
||||
sshConfigDir, knownHostsDir := getSystemSSHPaths()
|
||||
// New creates a new SSH config manager
|
||||
func New() *Manager {
|
||||
sshConfigDir := getSystemSSHConfigDir()
|
||||
return &Manager{
|
||||
sshConfigDir: sshConfigDir,
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: knownHostsDir,
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: sshConfigDir,
|
||||
sshConfigFile: nbssh.NetBirdSSHConfigFile,
|
||||
}
|
||||
}
|
||||
|
||||
// getSystemSSHPaths returns platform-specific SSH configuration paths
|
||||
func getSystemSSHPaths() (configDir, knownHostsDir string) {
|
||||
switch runtime.GOOS {
|
||||
case "windows":
|
||||
configDir, knownHostsDir = getWindowsSSHPaths()
|
||||
default:
|
||||
// Unix-like systems (Linux, macOS, etc.)
|
||||
configDir = "/etc/ssh/ssh_config.d"
|
||||
knownHostsDir = "/etc/ssh/ssh_known_hosts.d"
|
||||
// getSystemSSHConfigDir returns platform-specific SSH configuration directory
|
||||
func getSystemSSHConfigDir() string {
|
||||
if runtime.GOOS == "windows" {
|
||||
return getWindowsSSHConfigDir()
|
||||
}
|
||||
return configDir, knownHostsDir
|
||||
return nbssh.UnixSSHConfigDir
|
||||
}
|
||||
|
||||
func getWindowsSSHPaths() (configDir, knownHostsDir string) {
|
||||
func getWindowsSSHConfigDir() string {
|
||||
programData := os.Getenv("PROGRAMDATA")
|
||||
if programData == "" {
|
||||
programData = `C:\ProgramData`
|
||||
}
|
||||
configDir = filepath.Join(programData, "ssh", "ssh_config.d")
|
||||
knownHostsDir = filepath.Join(programData, "ssh", "ssh_known_hosts.d")
|
||||
return configDir, knownHostsDir
|
||||
return filepath.Join(programData, nbssh.WindowsSSHConfigDir)
|
||||
}
|
||||
|
||||
// SetupSSHClientConfig creates SSH client configuration for NetBird peers
|
||||
func (m *Manager) SetupSSHClientConfig(peerKeys []PeerHostKey) error {
|
||||
if !shouldGenerateSSHConfig(len(peerKeys)) {
|
||||
m.logSkipReason(len(peerKeys))
|
||||
func (m *Manager) SetupSSHClientConfig(peers []PeerSSHInfo) error {
|
||||
if !shouldGenerateSSHConfig(len(peers)) {
|
||||
m.logSkipReason(len(peers))
|
||||
return nil
|
||||
}
|
||||
|
||||
knownHostsPath := m.getKnownHostsPath()
|
||||
sshConfig := m.buildSSHConfig(peerKeys, knownHostsPath)
|
||||
sshConfig, err := m.buildSSHConfig(peers)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build SSH config: %w", err)
|
||||
}
|
||||
return m.writeSSHConfig(sshConfig)
|
||||
}
|
||||
|
||||
@@ -183,21 +143,24 @@ func (m *Manager) logSkipReason(peerCount int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) getKnownHostsPath() string {
|
||||
knownHostsPath, err := m.setupKnownHostsFile()
|
||||
if err != nil {
|
||||
log.Warnf("Failed to setup known_hosts file: %v", err)
|
||||
return "/dev/null"
|
||||
}
|
||||
return knownHostsPath
|
||||
}
|
||||
|
||||
func (m *Manager) buildSSHConfig(peerKeys []PeerHostKey, knownHostsPath string) string {
|
||||
func (m *Manager) buildSSHConfig(peers []PeerSSHInfo) (string, error) {
|
||||
sshConfig := m.buildConfigHeader()
|
||||
for _, peer := range peerKeys {
|
||||
sshConfig += m.buildPeerConfig(peer, knownHostsPath)
|
||||
|
||||
var allHostPatterns []string
|
||||
for _, peer := range peers {
|
||||
hostPatterns := m.buildHostPatterns(peer)
|
||||
allHostPatterns = append(allHostPatterns, hostPatterns...)
|
||||
}
|
||||
return sshConfig
|
||||
|
||||
if len(allHostPatterns) > 0 {
|
||||
peerConfig, err := m.buildPeerConfig(allHostPatterns)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
sshConfig += peerConfig
|
||||
}
|
||||
|
||||
return sshConfig, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildConfigHeader() string {
|
||||
@@ -209,25 +172,49 @@ func (m *Manager) buildConfigHeader() string {
|
||||
"#\n\n"
|
||||
}
|
||||
|
||||
func (m *Manager) buildPeerConfig(peer PeerHostKey, knownHostsPath string) string {
|
||||
hostPatterns := m.buildHostPatterns(peer)
|
||||
if len(hostPatterns) == 0 {
|
||||
return ""
|
||||
func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
uniquePatterns := make(map[string]bool)
|
||||
var deduplicatedPatterns []string
|
||||
for _, pattern := range allHostPatterns {
|
||||
if !uniquePatterns[pattern] {
|
||||
uniquePatterns[pattern] = true
|
||||
deduplicatedPatterns = append(deduplicatedPatterns, pattern)
|
||||
}
|
||||
}
|
||||
|
||||
hostLine := strings.Join(hostPatterns, " ")
|
||||
execPath, err := m.getNetBirdExecutablePath()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||
}
|
||||
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
config += " # NetBird peer-specific configuration\n"
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += m.buildHostKeyConfig(knownHostsPath)
|
||||
config += " LogLevel ERROR\n\n"
|
||||
return config
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
} else {
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p 2>/dev/null\"\n", execPath)
|
||||
}
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
} else {
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
|
||||
func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
||||
var hostPatterns []string
|
||||
if peer.IP != "" {
|
||||
hostPatterns = append(hostPatterns, peer.IP)
|
||||
@@ -241,280 +228,55 @@ func (m *Manager) buildHostPatterns(peer PeerHostKey) []string {
|
||||
return hostPatterns
|
||||
}
|
||||
|
||||
func (m *Manager) buildHostKeyConfig(knownHostsPath string) string {
|
||||
if knownHostsPath == "/dev/null" {
|
||||
return " StrictHostKeyChecking no\n" +
|
||||
" UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
return " StrictHostKeyChecking yes\n" +
|
||||
fmt.Sprintf(" UserKnownHostsFile %s\n", knownHostsPath)
|
||||
}
|
||||
|
||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
|
||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||
log.Warnf("Failed to create SSH config directory %s: %v", m.sshConfigDir, err)
|
||||
return m.setupUserConfig(sshConfig)
|
||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||
}
|
||||
|
||||
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||
log.Warnf("Failed to write SSH config file %s: %v", sshConfigPath, err)
|
||||
return m.setupUserConfig(sshConfig)
|
||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||
}
|
||||
|
||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupUserConfig creates SSH config in user's directory as fallback
|
||||
func (m *Manager) setupUserConfig(sshConfig string) error {
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
return fmt.Errorf("get user home directory: %w", err)
|
||||
}
|
||||
|
||||
userSSHDir := filepath.Join(homeDir, ".ssh")
|
||||
userConfigPath := filepath.Join(userSSHDir, "config")
|
||||
|
||||
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
|
||||
return fmt.Errorf("create user SSH directory: %w", err)
|
||||
}
|
||||
|
||||
// Check if NetBird config already exists in user config
|
||||
exists, err := m.configExists(userConfigPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("check existing config: %w", err)
|
||||
}
|
||||
|
||||
if exists {
|
||||
log.Debugf("NetBird SSH config already exists in %s", userConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// Append NetBird config to user's SSH config with timeout
|
||||
if err := writeFileOperationWithTimeout(userConfigPath, func() error {
|
||||
file, err := os.OpenFile(userConfigPath, os.O_APPEND|os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open user SSH config: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := file.Close(); err != nil {
|
||||
log.Debugf("user SSH config file close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
if _, err := fmt.Fprintf(file, "\n%s", sshConfig); err != nil {
|
||||
return fmt.Errorf("write to user SSH config: %w", err)
|
||||
}
|
||||
return nil
|
||||
}); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Infof("Added NetBird SSH config to user config: %s", userConfigPath)
|
||||
return nil
|
||||
}
|
||||
|
||||
// configExists checks if NetBird SSH config already exists
|
||||
func (m *Manager) configExists(configPath string) (bool, error) {
|
||||
file, err := os.Open(configPath)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("open SSH config file: %w", err)
|
||||
}
|
||||
defer file.Close()
|
||||
|
||||
scanner := bufio.NewScanner(file)
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.Contains(line, "NetBird SSH client configuration") {
|
||||
return true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return false, scanner.Err()
|
||||
}
|
||||
|
||||
// RemoveSSHClientConfig removes NetBird SSH configuration
|
||||
func (m *Manager) RemoveSSHClientConfig() error {
|
||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||
|
||||
// Remove system-wide config if it exists
|
||||
if err := os.Remove(sshConfigPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("Failed to remove system SSH config %s: %v", sshConfigPath, err)
|
||||
} else if err == nil {
|
||||
err := os.Remove(sshConfigPath)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
return fmt.Errorf("remove SSH config %s: %w", sshConfigPath, err)
|
||||
}
|
||||
if err == nil {
|
||||
log.Infof("Removed NetBird SSH config: %s", sshConfigPath)
|
||||
}
|
||||
|
||||
// Also try to clean up user config
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user home directory: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
userConfigPath := filepath.Join(homeDir, ".ssh", "config")
|
||||
if err := m.removeFromUserConfig(userConfigPath); err != nil {
|
||||
log.Warnf("Failed to clean user SSH config: %v", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// removeFromUserConfig removes NetBird section from user's SSH config
|
||||
func (m *Manager) removeFromUserConfig(configPath string) error {
|
||||
// This is complex to implement safely, so for now just log
|
||||
// In practice, the system-wide config takes precedence anyway
|
||||
log.Debugf("NetBird SSH config cleanup from user config not implemented")
|
||||
return nil
|
||||
}
|
||||
|
||||
// setupKnownHostsFile creates and returns the path to NetBird known_hosts file
|
||||
func (m *Manager) setupKnownHostsFile() (string, error) {
|
||||
// Try system-wide known_hosts first
|
||||
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
|
||||
if err := os.MkdirAll(m.knownHostsDir, 0755); err == nil {
|
||||
// Create empty file if it doesn't exist
|
||||
if _, err := os.Stat(knownHostsPath); os.IsNotExist(err) {
|
||||
if err := writeFileWithTimeout(knownHostsPath, []byte("# NetBird SSH known hosts\n"), 0644); err == nil {
|
||||
log.Debugf("Created NetBird known_hosts file: %s", knownHostsPath)
|
||||
return knownHostsPath, nil
|
||||
}
|
||||
} else if err == nil {
|
||||
return knownHostsPath, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Fallback to user directory
|
||||
homeDir, err := os.UserHomeDir()
|
||||
func (m *Manager) getNetBirdExecutablePath() (string, error) {
|
||||
execPath, err := os.Executable()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("get user home directory: %w", err)
|
||||
return "", fmt.Errorf("retrieve executable path: %w", err)
|
||||
}
|
||||
|
||||
userSSHDir := filepath.Join(homeDir, ".ssh")
|
||||
if err := os.MkdirAll(userSSHDir, 0700); err != nil {
|
||||
return "", fmt.Errorf("create user SSH directory: %w", err)
|
||||
}
|
||||
|
||||
userKnownHostsPath := filepath.Join(userSSHDir, m.userKnownHosts)
|
||||
if _, err := os.Stat(userKnownHostsPath); os.IsNotExist(err) {
|
||||
if err := writeFileWithTimeout(userKnownHostsPath, []byte("# NetBird SSH known hosts\n"), 0600); err != nil {
|
||||
return "", fmt.Errorf("create user known_hosts file: %w", err)
|
||||
}
|
||||
log.Debugf("Created NetBird user known_hosts file: %s", userKnownHostsPath)
|
||||
}
|
||||
|
||||
return userKnownHostsPath, nil
|
||||
}
|
||||
|
||||
// UpdatePeerHostKeys updates the known_hosts file with peer host keys
|
||||
func (m *Manager) UpdatePeerHostKeys(peerKeys []PeerHostKey) error {
|
||||
peerCount := len(peerKeys)
|
||||
|
||||
// Check if SSH config should be generated
|
||||
if !shouldGenerateSSHConfig(peerCount) {
|
||||
if isSSHConfigDisabled() {
|
||||
log.Debugf("SSH config management disabled via %s", EnvDisableSSHConfig)
|
||||
} else {
|
||||
log.Infof("SSH known_hosts update skipped: too many peers (%d > %d). Use %s=true to force.",
|
||||
peerCount, MaxPeersForSSHConfig, EnvForceSSHConfig)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
knownHostsPath, err := m.setupKnownHostsFile()
|
||||
realPath, err := filepath.EvalSymlinks(execPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("setup known_hosts file: %w", err)
|
||||
log.Debugf("symlink resolution failed: %v", err)
|
||||
return execPath, nil
|
||||
}
|
||||
|
||||
// Create updated known_hosts content - NetBird file should only contain NetBird entries
|
||||
var updatedContent strings.Builder
|
||||
updatedContent.WriteString("# NetBird SSH known hosts\n")
|
||||
updatedContent.WriteString("# Generated automatically - do not edit manually\n\n")
|
||||
|
||||
// Add new NetBird entries - one entry per peer with all hostnames
|
||||
for _, peerKey := range peerKeys {
|
||||
entry := m.formatKnownHostsEntry(peerKey)
|
||||
updatedContent.WriteString(entry)
|
||||
updatedContent.WriteString("\n")
|
||||
}
|
||||
|
||||
// Write updated content
|
||||
if err := writeFileWithTimeout(knownHostsPath, []byte(updatedContent.String()), 0644); err != nil {
|
||||
return fmt.Errorf("write known_hosts file: %w", err)
|
||||
}
|
||||
|
||||
log.Debugf("Updated NetBird known_hosts with %d peer keys: %s", len(peerKeys), knownHostsPath)
|
||||
return nil
|
||||
return realPath, nil
|
||||
}
|
||||
|
||||
// formatKnownHostsEntry formats a peer host key as a known_hosts entry
|
||||
func (m *Manager) formatKnownHostsEntry(peerKey PeerHostKey) string {
|
||||
hostnames := m.getHostnameVariants(peerKey)
|
||||
hostnameList := strings.Join(hostnames, ",")
|
||||
keyString := string(ssh.MarshalAuthorizedKey(peerKey.HostKey))
|
||||
keyString = strings.TrimSpace(keyString)
|
||||
return fmt.Sprintf("%s %s", hostnameList, keyString)
|
||||
// GetSSHConfigDir returns the SSH config directory path
|
||||
func (m *Manager) GetSSHConfigDir() string {
|
||||
return m.sshConfigDir
|
||||
}
|
||||
|
||||
// getHostnameVariants returns all possible hostname variants for a peer
|
||||
func (m *Manager) getHostnameVariants(peerKey PeerHostKey) []string {
|
||||
var hostnames []string
|
||||
|
||||
// Add IP address
|
||||
if peerKey.IP != "" {
|
||||
hostnames = append(hostnames, peerKey.IP)
|
||||
}
|
||||
|
||||
// Add FQDN
|
||||
if peerKey.FQDN != "" {
|
||||
hostnames = append(hostnames, peerKey.FQDN)
|
||||
}
|
||||
|
||||
// Add hostname if different from FQDN
|
||||
if peerKey.Hostname != "" && peerKey.Hostname != peerKey.FQDN {
|
||||
hostnames = append(hostnames, peerKey.Hostname)
|
||||
}
|
||||
|
||||
// Add bracketed IP for non-standard ports (SSH standard)
|
||||
if peerKey.IP != "" {
|
||||
hostnames = append(hostnames, fmt.Sprintf("[%s]:22", peerKey.IP))
|
||||
hostnames = append(hostnames, fmt.Sprintf("[%s]:22022", peerKey.IP))
|
||||
}
|
||||
|
||||
return hostnames
|
||||
}
|
||||
|
||||
// GetKnownHostsPath returns the path to the NetBird known_hosts file
|
||||
func (m *Manager) GetKnownHostsPath() (string, error) {
|
||||
return m.setupKnownHostsFile()
|
||||
}
|
||||
|
||||
// RemoveKnownHostsFile removes the NetBird known_hosts file
|
||||
func (m *Manager) RemoveKnownHostsFile() error {
|
||||
// Remove system-wide known_hosts if it exists
|
||||
knownHostsPath := filepath.Join(m.knownHostsDir, m.knownHostsFile)
|
||||
if err := os.Remove(knownHostsPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("Failed to remove system known_hosts %s: %v", knownHostsPath, err)
|
||||
} else if err == nil {
|
||||
log.Infof("Removed NetBird known_hosts: %s", knownHostsPath)
|
||||
}
|
||||
|
||||
// Also try to clean up user known_hosts
|
||||
homeDir, err := os.UserHomeDir()
|
||||
if err != nil {
|
||||
log.Debugf("failed to get user home directory: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
userKnownHostsPath := filepath.Join(homeDir, ".ssh", m.userKnownHosts)
|
||||
if err := os.Remove(userKnownHostsPath); err != nil && !os.IsNotExist(err) {
|
||||
log.Warnf("Failed to remove user known_hosts %s: %v", userKnownHostsPath, err)
|
||||
} else if err == nil {
|
||||
log.Infof("Removed NetBird user known_hosts: %s", userKnownHostsPath)
|
||||
}
|
||||
|
||||
return nil
|
||||
// GetSSHConfigFile returns the SSH config file name
|
||||
func (m *Manager) GetSSHConfigFile() string {
|
||||
return m.sshConfigFile
|
||||
}
|
||||
|
||||
@@ -10,81 +10,8 @@ import (
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
)
|
||||
|
||||
func TestManager_UpdatePeerHostKeys(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
}
|
||||
|
||||
// Generate test host keys
|
||||
hostKey1, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey1, err := ssh.ParsePrivateKey(hostKey1)
|
||||
require.NoError(t, err)
|
||||
|
||||
hostKey2, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey2, err := ssh.ParsePrivateKey(hostKey2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create test peer host keys
|
||||
peerKeys := []PeerHostKey{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
FQDN: "peer1.nb.internal",
|
||||
HostKey: pubKey1.PublicKey(),
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
FQDN: "peer2.nb.internal",
|
||||
HostKey: pubKey2.PublicKey(),
|
||||
},
|
||||
}
|
||||
|
||||
// Test updating known_hosts
|
||||
err = manager.UpdatePeerHostKeys(peerKeys)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify known_hosts file was created and contains entries
|
||||
knownHostsPath, err := manager.GetKnownHostsPath()
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err := os.ReadFile(knownHostsPath)
|
||||
require.NoError(t, err)
|
||||
|
||||
contentStr := string(content)
|
||||
assert.Contains(t, contentStr, "100.125.1.1")
|
||||
assert.Contains(t, contentStr, "100.125.1.2")
|
||||
assert.Contains(t, contentStr, "peer1.nb.internal")
|
||||
assert.Contains(t, contentStr, "peer2.nb.internal")
|
||||
assert.Contains(t, contentStr, "[100.125.1.1]:22")
|
||||
assert.Contains(t, contentStr, "[100.125.1.1]:22022")
|
||||
|
||||
// Test updating with empty list should preserve structure
|
||||
err = manager.UpdatePeerHostKeys([]PeerHostKey{})
|
||||
require.NoError(t, err)
|
||||
|
||||
content, err = os.ReadFile(knownHostsPath)
|
||||
require.NoError(t, err)
|
||||
assert.Contains(t, string(content), "# NetBird SSH known hosts")
|
||||
}
|
||||
|
||||
func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
// Create temporary directory for test
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
@@ -93,15 +20,25 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Test SSH config generation with empty peer keys
|
||||
err = manager.SetupSSHClientConfig(nil)
|
||||
// Test SSH config generation with peers
|
||||
peers := []PeerSSHInfo{
|
||||
{
|
||||
Hostname: "peer1",
|
||||
IP: "100.125.1.1",
|
||||
FQDN: "peer1.nb.internal",
|
||||
},
|
||||
{
|
||||
Hostname: "peer2",
|
||||
IP: "100.125.1.2",
|
||||
FQDN: "peer2.nb.internal",
|
||||
},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Read generated config
|
||||
@@ -111,134 +48,39 @@ func TestManager_SetupSSHClientConfig(t *testing.T) {
|
||||
|
||||
configStr := string(content)
|
||||
|
||||
// Since we now use per-peer configurations instead of domain patterns,
|
||||
// we should verify the basic SSH config structure exists
|
||||
// Verify the basic SSH config structure exists
|
||||
assert.Contains(t, configStr, "# NetBird SSH client configuration")
|
||||
assert.Contains(t, configStr, "Generated automatically - do not edit manually")
|
||||
|
||||
// Should not contain /dev/null since we have a proper known_hosts setup
|
||||
assert.NotContains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||
}
|
||||
// Check that peer hostnames are included
|
||||
assert.Contains(t, configStr, "100.125.1.1")
|
||||
assert.Contains(t, configStr, "100.125.1.2")
|
||||
assert.Contains(t, configStr, "peer1.nb.internal")
|
||||
assert.Contains(t, configStr, "peer2.nb.internal")
|
||||
|
||||
func TestManager_GetHostnameVariants(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
peerKey := PeerHostKey{
|
||||
Hostname: "testpeer",
|
||||
IP: "100.125.1.10",
|
||||
FQDN: "testpeer.nb.internal",
|
||||
HostKey: nil, // Not needed for this test
|
||||
}
|
||||
|
||||
variants := manager.getHostnameVariants(peerKey)
|
||||
|
||||
expectedVariants := []string{
|
||||
"100.125.1.10",
|
||||
"testpeer.nb.internal",
|
||||
"testpeer",
|
||||
"[100.125.1.10]:22",
|
||||
"[100.125.1.10]:22022",
|
||||
}
|
||||
|
||||
assert.ElementsMatch(t, expectedVariants, variants)
|
||||
}
|
||||
|
||||
func TestManager_FormatKnownHostsEntry(t *testing.T) {
|
||||
manager := NewManager()
|
||||
|
||||
// Generate test key
|
||||
hostKeyPEM, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
parsedKey, err := ssh.ParsePrivateKey(hostKeyPEM)
|
||||
require.NoError(t, err)
|
||||
|
||||
peerKey := PeerHostKey{
|
||||
Hostname: "testpeer",
|
||||
IP: "100.125.1.10",
|
||||
FQDN: "testpeer.nb.internal",
|
||||
HostKey: parsedKey.PublicKey(),
|
||||
}
|
||||
|
||||
entry := manager.formatKnownHostsEntry(peerKey)
|
||||
|
||||
// Should contain all hostname variants
|
||||
assert.Contains(t, entry, "100.125.1.10")
|
||||
assert.Contains(t, entry, "testpeer.nb.internal")
|
||||
assert.Contains(t, entry, "testpeer")
|
||||
assert.Contains(t, entry, "[100.125.1.10]:22")
|
||||
assert.Contains(t, entry, "[100.125.1.10]:22022")
|
||||
|
||||
// Should contain the public key
|
||||
keyString := string(ssh.MarshalAuthorizedKey(parsedKey.PublicKey()))
|
||||
keyString = strings.TrimSpace(keyString)
|
||||
assert.Contains(t, entry, keyString)
|
||||
|
||||
// Should be properly formatted (hostnames followed by key)
|
||||
parts := strings.Fields(entry)
|
||||
assert.GreaterOrEqual(t, len(parts), 2, "Entry should have hostnames and key parts")
|
||||
}
|
||||
|
||||
func TestManager_DirectoryFallback(t *testing.T) {
|
||||
// Create temporary directory for test where system dirs will fail
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
// Set HOME to temp directory to control user fallback
|
||||
t.Setenv("HOME", tempDir)
|
||||
|
||||
// Create manager with non-writable system directories
|
||||
// Use paths that will fail on all systems
|
||||
var failPath string
|
||||
// Check platform-specific UserKnownHostsFile
|
||||
if runtime.GOOS == "windows" {
|
||||
failPath = "NUL:" // Special device that can't be used as directory on Windows
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile NUL")
|
||||
} else {
|
||||
failPath = "/dev/null" // Special device that can't be used as directory on Unix
|
||||
assert.Contains(t, configStr, "UserKnownHostsFile /dev/null")
|
||||
}
|
||||
|
||||
manager := &Manager{
|
||||
sshConfigDir: failPath + "/ssh_config.d", // Should fail
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: failPath + "/ssh_known_hosts.d", // Should fail
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
}
|
||||
|
||||
// Should fall back to user directory
|
||||
knownHostsPath, err := manager.setupKnownHostsFile()
|
||||
require.NoError(t, err)
|
||||
|
||||
// Get the actual user home directory as determined by os.UserHomeDir()
|
||||
userHome, err := os.UserHomeDir()
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedUserPath := filepath.Join(userHome, ".ssh", "known_hosts_netbird")
|
||||
assert.Equal(t, expectedUserPath, knownHostsPath)
|
||||
|
||||
// Verify file was created
|
||||
_, err = os.Stat(knownHostsPath)
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestGetSystemSSHPaths(t *testing.T) {
|
||||
configDir, knownHostsDir := getSystemSSHPaths()
|
||||
func TestGetSystemSSHConfigDir(t *testing.T) {
|
||||
configDir := getSystemSSHConfigDir()
|
||||
|
||||
// Paths should not be empty
|
||||
// Path should not be empty
|
||||
assert.NotEmpty(t, configDir)
|
||||
assert.NotEmpty(t, knownHostsDir)
|
||||
|
||||
// Should be absolute paths
|
||||
// Should be an absolute path
|
||||
assert.True(t, filepath.IsAbs(configDir))
|
||||
assert.True(t, filepath.IsAbs(knownHostsDir))
|
||||
|
||||
// On Unix systems, should start with /etc
|
||||
// On Windows, should contain ProgramData
|
||||
if runtime.GOOS == "windows" {
|
||||
assert.Contains(t, strings.ToLower(configDir), "programdata")
|
||||
assert.Contains(t, strings.ToLower(knownHostsDir), "programdata")
|
||||
} else {
|
||||
assert.Contains(t, configDir, "/etc/ssh")
|
||||
assert.Contains(t, knownHostsDir, "/etc/ssh")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -250,46 +92,28 @@ func TestManager_PeerLimit(t *testing.T) {
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peer keys (more than limit)
|
||||
var peerKeys []PeerHostKey
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey, err := ssh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
peerKeys = append(peerKeys, PeerHostKey{
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
HostKey: pubKey.PublicKey(),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is skipped when too many peers
|
||||
err = manager.SetupSSHClientConfig(peerKeys)
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should not be created due to peer limit
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
_, err = os.Stat(configPath)
|
||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||
|
||||
// Test that known_hosts update is also skipped
|
||||
err = manager.UpdatePeerHostKeys(peerKeys)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Known hosts should not be created due to peer limit
|
||||
knownHostsPath := filepath.Join(manager.knownHostsDir, manager.knownHostsFile)
|
||||
_, err = os.Stat(knownHostsPath)
|
||||
assert.True(t, os.IsNotExist(err), "Known hosts should not be created with too many peers")
|
||||
}
|
||||
|
||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
@@ -303,31 +127,22 @@ func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
|
||||
// Override manager paths to use temp directory
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
knownHostsDir: filepath.Join(tempDir, "ssh_known_hosts.d"),
|
||||
knownHostsFile: "99-netbird",
|
||||
userKnownHosts: "known_hosts_netbird",
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
// Generate many peer keys (more than limit)
|
||||
var peerKeys []PeerHostKey
|
||||
// Generate many peers (more than limit)
|
||||
var peers []PeerSSHInfo
|
||||
for i := 0; i < MaxPeersForSSHConfig+10; i++ {
|
||||
hostKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
|
||||
require.NoError(t, err)
|
||||
pubKey, err := ssh.ParsePrivateKey(hostKey)
|
||||
require.NoError(t, err)
|
||||
|
||||
peerKeys = append(peerKeys, PeerHostKey{
|
||||
peers = append(peers, PeerSSHInfo{
|
||||
Hostname: fmt.Sprintf("peer%d", i),
|
||||
IP: fmt.Sprintf("100.125.1.%d", i%254+1),
|
||||
FQDN: fmt.Sprintf("peer%d.nb.internal", i),
|
||||
HostKey: pubKey.PublicKey(),
|
||||
})
|
||||
}
|
||||
|
||||
// Test that SSH config generation is forced despite many peers
|
||||
err = manager.SetupSSHClientConfig(peerKeys)
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Config should be created despite peer limit due to force flag
|
||||
|
||||
22
client/ssh/config/shutdown_state.go
Normal file
22
client/ssh/config/shutdown_state.go
Normal file
@@ -0,0 +1,22 @@
|
||||
package config
|
||||
|
||||
// ShutdownState represents SSH configuration state that needs to be cleaned up.
|
||||
type ShutdownState struct {
|
||||
SSHConfigDir string
|
||||
SSHConfigFile string
|
||||
}
|
||||
|
||||
// Name returns the state name for the state manager.
|
||||
func (s *ShutdownState) Name() string {
|
||||
return "ssh_config_state"
|
||||
}
|
||||
|
||||
// Cleanup removes SSH client configuration files.
|
||||
func (s *ShutdownState) Cleanup() error {
|
||||
manager := &Manager{
|
||||
sshConfigDir: s.SSHConfigDir,
|
||||
sshConfigFile: s.SSHConfigFile,
|
||||
}
|
||||
|
||||
return manager.RemoveSSHClientConfig()
|
||||
}
|
||||
Reference in New Issue
Block a user