Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -5,10 +5,8 @@ import (
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
@@ -21,9 +19,6 @@ import (
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
RemoveAuthorizedKey(peer string)
AddAuthorizedKey(peer, newKey string) error
SetSocketFilter(ifIdx int)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -44,22 +39,6 @@ func (e *Engine) setupSSHPortRedirection() error {
return nil
}
func (e *Engine) setupSSHSocketFilter(server sshServer) error {
if runtime.GOOS != "linux" {
return nil
}
netInterface := e.wgInterface.ToInterface()
if netInterface == nil {
return errors.New("failed to get WireGuard network interface")
}
server.SetSocketFilter(netInterface.Index)
log.Debugf("SSH socket filter configured for interface %s (index: %d)", netInterface.Name, netInterface.Index)
return nil
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Info("SSH server is disabled because inbound connections are blocked")
@@ -83,66 +62,76 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
return nil
}
return e.startSSHServer()
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
return e.startSSHServer(nil)
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
return e.startSSHServer(jwtConfig)
}
return errors.New("SSH server requires valid JWT configuration")
}
// updateSSHKnownHosts updates the SSH known_hosts file with peer host keys for OpenSSH client
func (e *Engine) updateSSHKnownHosts(remotePeers []*mgmProto.RemotePeerConfig) error {
peerKeys := e.extractPeerHostKeys(remotePeers)
if len(peerKeys) == 0 {
log.Debug("no SSH-enabled peers found, skipping known_hosts update")
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
return nil
}
if err := e.updateKnownHostsFile(peerKeys); err != nil {
return err
configMgr := sshconfig.New()
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
return nil // Don't fail engine startup on SSH config issues
}
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
SSHConfigDir: configMgr.GetSSHConfigDir(),
SSHConfigFile: configMgr.GetSSHConfigFile(),
}); err != nil {
log.Warnf("failed to update SSH config state: %v", err)
}
e.updateSSHClientConfig(peerKeys)
log.Debugf("updated SSH known_hosts with %d peer host keys", len(peerKeys))
return nil
}
// extractPeerHostKeys extracts SSH host keys from peer configurations
func (e *Engine) extractPeerHostKeys(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerHostKey {
var peerKeys []sshconfig.PeerHostKey
// extractPeerSSHInfo extracts SSH information from peer configurations
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
var peerInfo []sshconfig.PeerSSHInfo
for _, peerConfig := range remotePeers {
peerHostKey, ok := e.parsePeerHostKey(peerConfig)
if ok {
peerKeys = append(peerKeys, peerHostKey)
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
})
}
return peerKeys
}
// parsePeerHostKey parses a single peer's SSH host key configuration
func (e *Engine) parsePeerHostKey(peerConfig *mgmProto.RemotePeerConfig) (sshconfig.PeerHostKey, bool) {
if peerConfig.GetSshConfig() == nil {
return sshconfig.PeerHostKey{}, false
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
return sshconfig.PeerHostKey{}, false
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyBytes)
if err != nil {
log.Warnf("failed to parse SSH public key for peer %s: %v", peerConfig.GetWgPubKey(), err)
return sshconfig.PeerHostKey{}, false
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
return sshconfig.PeerHostKey{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
HostKey: hostKey,
}, true
return peerInfo
}
// extractPeerIP extracts IP address from peer's allowed IPs
@@ -171,25 +160,6 @@ func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
return ""
}
// updateKnownHostsFile updates the SSH known_hosts file
func (e *Engine) updateKnownHostsFile(peerKeys []sshconfig.PeerHostKey) error {
configMgr := sshconfig.NewManager()
if err := configMgr.UpdatePeerHostKeys(peerKeys); err != nil {
return fmt.Errorf("update peer host keys: %w", err)
}
return nil
}
// updateSSHClientConfig updates SSH client configuration with peer hostnames
func (e *Engine) updateSSHClientConfig(peerKeys []sshconfig.PeerHostKey) {
configMgr := sshconfig.NewManager()
if err := configMgr.SetupSSHClientConfig(peerKeys); err != nil {
log.Warnf("failed to update SSH client config with peer hostnames: %v", err)
} else {
log.Debugf("updated SSH client config with %d peer hostnames", len(peerKeys))
}
}
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
for _, peerConfig := range remotePeers {
@@ -210,30 +180,51 @@ func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig)
log.Debugf("updated peer SSH host keys for daemon API access")
}
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
e.syncMsgMux.Lock()
statusRecorder := e.statusRecorder
e.syncMsgMux.Unlock()
if statusRecorder == nil {
return nil, false
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
return peerState.SSHHostKey, true
}
return nil, false
}
}
return nil, false
}
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
configMgr := sshconfig.NewManager()
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {
log.Warnf("failed to remove SSH client config: %v", err)
} else {
log.Debugf("SSH client config cleanup completed")
}
if err := configMgr.RemoveKnownHostsFile(); err != nil {
log.Warnf("failed to remove SSH known_hosts: %v", err)
} else {
log.Debugf("SSH known_hosts cleanup completed")
}
}
// startSSHServer initializes and starts the SSH server with proper configuration.
func (e *Engine) startSSHServer() error {
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
server := sshserver.New(e.config.SSHKey)
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
@@ -259,15 +250,10 @@ func (e *Engine) startSSHServer() error {
log.Warnf("failed to setup SSH port redirection: %v", err)
}
if err := e.setupSSHSocketFilter(server); err != nil {
return fmt.Errorf("set socket filter: %w", err)
}
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
return nil
}