Files
netbird/client/internal/engine_ssh.go
2025-07-02 19:35:19 +02:00

360 lines
10 KiB
Go

package internal
import (
"context"
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
sshserver "github.com/netbirdio/netbird/client/ssh/server"
mgmProto "github.com/netbirdio/netbird/management/proto"
)
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
RemoveAuthorizedKey(peer string)
AddAuthorizedKey(peer, newKey string) error
SetSocketFilter(ifIdx int)
SetupSSHClientConfigWithPeers(peerKeys []sshconfig.PeerHostKey) error
}
func (e *Engine) setupSSHPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.AddInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
return fmt.Errorf("add SSH port redirection: %w", err)
}
log.Infof("SSH port redirection enabled: %s:22 -> %s:22022", localAddr, localAddr)
return nil
}
func (e *Engine) setupSSHSocketFilter(server sshServer) error {
if runtime.GOOS != "linux" {
return nil
}
netInterface := e.wgInterface.ToInterface()
if netInterface == nil {
return errors.New("failed to get WireGuard network interface")
}
server.SetSocketFilter(netInterface.Index)
log.Debugf("SSH socket filter configured for interface %s (index: %d)", netInterface.Name, netInterface.Index)
return nil
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Info("SSH server is disabled because inbound connections are blocked")
return e.stopSSHServer()
}
if !e.config.ServerSSHAllowed {
log.Info("SSH server is disabled in config")
return e.stopSSHServer()
}
if !sshConf.GetSshEnabled() {
if e.config.ServerSSHAllowed {
log.Info("SSH server is locally allowed but disabled by management server")
}
return e.stopSSHServer()
}
if e.sshServer != nil {
log.Debug("SSH server is already running")
return nil
}
return e.startSSHServer()
}
// updateSSHKnownHosts updates the SSH known_hosts file with peer host keys for OpenSSH client
func (e *Engine) updateSSHKnownHosts(remotePeers []*mgmProto.RemotePeerConfig) error {
peerKeys := e.extractPeerHostKeys(remotePeers)
if len(peerKeys) == 0 {
log.Debug("no SSH-enabled peers found, skipping known_hosts update")
return nil
}
if err := e.updateKnownHostsFile(peerKeys); err != nil {
return err
}
e.updateSSHClientConfig(peerKeys)
log.Debugf("updated SSH known_hosts with %d peer host keys", len(peerKeys))
return nil
}
// extractPeerHostKeys extracts SSH host keys from peer configurations
func (e *Engine) extractPeerHostKeys(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerHostKey {
var peerKeys []sshconfig.PeerHostKey
for _, peerConfig := range remotePeers {
peerHostKey, ok := e.parsePeerHostKey(peerConfig)
if ok {
peerKeys = append(peerKeys, peerHostKey)
}
}
return peerKeys
}
// parsePeerHostKey parses a single peer's SSH host key configuration
func (e *Engine) parsePeerHostKey(peerConfig *mgmProto.RemotePeerConfig) (sshconfig.PeerHostKey, bool) {
if peerConfig.GetSshConfig() == nil {
return sshconfig.PeerHostKey{}, false
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
return sshconfig.PeerHostKey{}, false
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyBytes)
if err != nil {
log.Warnf("failed to parse SSH public key for peer %s: %v", peerConfig.GetWgPubKey(), err)
return sshconfig.PeerHostKey{}, false
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
return sshconfig.PeerHostKey{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
HostKey: hostKey,
}, true
}
// extractPeerIP extracts IP address from peer's allowed IPs
func (e *Engine) extractPeerIP(peerConfig *mgmProto.RemotePeerConfig) string {
if len(peerConfig.GetAllowedIps()) == 0 {
return ""
}
if prefix, err := netip.ParsePrefix(peerConfig.GetAllowedIps()[0]); err == nil {
return prefix.Addr().String()
}
return ""
}
// extractHostname extracts short hostname from FQDN
func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
fqdn := peerConfig.GetFqdn()
if fqdn == "" {
return ""
}
parts := strings.Split(fqdn, ".")
if len(parts) > 0 && parts[0] != "" {
return parts[0]
}
return ""
}
// updateKnownHostsFile updates the SSH known_hosts file
func (e *Engine) updateKnownHostsFile(peerKeys []sshconfig.PeerHostKey) error {
configMgr := sshconfig.NewManager()
if err := configMgr.UpdatePeerHostKeys(peerKeys); err != nil {
return fmt.Errorf("update peer host keys: %w", err)
}
return nil
}
// updateSSHClientConfig updates SSH client configuration with peer hostnames
func (e *Engine) updateSSHClientConfig(peerKeys []sshconfig.PeerHostKey) {
if e.sshServer == nil {
return
}
if err := e.sshServer.SetupSSHClientConfigWithPeers(peerKeys); err != nil {
log.Warnf("failed to update SSH client config with peer hostnames: %v", err)
} else {
log.Debugf("updated SSH client config with %d peer hostnames", len(peerKeys))
}
}
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
for _, peerConfig := range remotePeers {
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
if err := e.statusRecorder.UpdatePeerSSHHostKey(peerConfig.GetWgPubKey(), sshPubKeyBytes); err != nil {
log.Warnf("failed to update SSH host key for peer %s: %v", peerConfig.GetWgPubKey(), err)
}
}
log.Debugf("updated peer SSH host keys for daemon API access")
}
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
configMgr := sshconfig.NewManager()
if err := configMgr.RemoveSSHClientConfig(); err != nil {
log.Warnf("failed to remove SSH client config: %v", err)
} else {
log.Debugf("SSH client config cleanup completed")
}
if err := configMgr.RemoveKnownHostsFile(); err != nil {
log.Warnf("failed to remove SSH known_hosts: %v", err)
} else {
log.Debugf("SSH known_hosts cleanup completed")
}
}
// startSSHServer initializes and starts the SSH server with proper configuration.
func (e *Engine) startSSHServer() error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
server := sshserver.New(e.config.SSHKey)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
netbirdIP := wgAddr.IP
listenAddr := netip.AddrPortFrom(netbirdIP, sshserver.InternalSSHPort)
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
server.SetNetstackNet(netstackNet)
if registrar, ok := e.firewall.(interface {
RegisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.RegisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
log.Debugf("registered SSH service with netstack for TCP:%d", sshserver.InternalSSHPort)
}
}
e.configureSSHServer(server)
e.sshServer = server
if err := e.setupSSHPortRedirection(); err != nil {
log.Warnf("failed to setup SSH port redirection: %v", err)
}
if err := e.setupSSHSocketFilter(server); err != nil {
return fmt.Errorf("set socket filter: %w", err)
}
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
if err := server.SetupSSHClientConfig(); err != nil {
log.Warnf("failed to setup SSH client config: %v", err)
}
return nil
}
// configureSSHServer applies SSH configuration options to the server.
func (e *Engine) configureSSHServer(server *sshserver.Server) {
if e.config.EnableSSHRoot != nil && *e.config.EnableSSHRoot {
server.SetAllowRootLogin(true)
log.Info("SSH root login enabled")
} else {
server.SetAllowRootLogin(false)
log.Info("SSH root login disabled (default)")
}
if e.config.EnableSSHSFTP != nil && *e.config.EnableSSHSFTP {
server.SetAllowSFTP(true)
log.Info("SSH SFTP subsystem enabled")
} else {
server.SetAllowSFTP(false)
log.Info("SSH SFTP subsystem disabled (default)")
}
if e.config.EnableSSHLocalPortForwarding != nil && *e.config.EnableSSHLocalPortForwarding {
server.SetAllowLocalPortForwarding(true)
log.Info("SSH local port forwarding enabled")
} else {
server.SetAllowLocalPortForwarding(false)
log.Info("SSH local port forwarding disabled (default)")
}
if e.config.EnableSSHRemotePortForwarding != nil && *e.config.EnableSSHRemotePortForwarding {
server.SetAllowRemotePortForwarding(true)
log.Info("SSH remote port forwarding enabled")
} else {
server.SetAllowRemotePortForwarding(false)
log.Info("SSH remote port forwarding disabled (default)")
}
}
func (e *Engine) cleanupSSHPortRedirection() error {
if e.firewall == nil || e.wgInterface == nil {
return nil
}
localAddr := e.wgInterface.Address().IP
if !localAddr.IsValid() {
return errors.New("invalid local NetBird address")
}
if err := e.firewall.RemoveInboundDNAT(localAddr, firewallManager.ProtocolTCP, 22, 22022); err != nil {
return fmt.Errorf("remove SSH port redirection: %w", err)
}
log.Debugf("SSH port redirection removed: %s:22 -> %s:22022", localAddr, localAddr)
return nil
}
func (e *Engine) stopSSHServer() error {
if e.sshServer == nil {
return nil
}
if err := e.cleanupSSHPortRedirection(); err != nil {
log.Warnf("failed to cleanup SSH port redirection: %v", err)
}
if netstackNet := e.wgInterface.GetNet(); netstackNet != nil {
if registrar, ok := e.firewall.(interface {
UnregisterNetstackService(protocol nftypes.Protocol, port uint16)
}); ok {
registrar.UnregisterNetstackService(nftypes.TCP, sshserver.InternalSSHPort)
log.Debugf("unregistered SSH service from netstack for TCP:%d", sshserver.InternalSSHPort)
}
}
log.Info("stopping SSH server")
err := e.sshServer.Stop()
e.sshServer = nil
if err != nil {
return fmt.Errorf("stop: %w", err)
}
return nil
}