Add ssh authenatication with jwt (#4550)

This commit is contained in:
Viktor Liu
2025-10-07 23:38:27 +02:00
committed by GitHub
parent 7e0bbaaa3c
commit d9efe4e944
50 changed files with 4429 additions and 2336 deletions

View File

@@ -87,7 +87,6 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRout
func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
rules := d.squashAcceptRules(networkMap)
// if we got empty rules list but management not set networkMap.FirewallRulesIsEmpty flag
// we have old version of management without rules handling, we should allow all traffic
if len(networkMap.FirewallRules) == 0 && !networkMap.FirewallRulesIsEmpty {
@@ -350,7 +349,7 @@ func (d *DefaultManager) getPeerRuleID(
//
// NOTE: It will not squash two rules for same protocol if one covers all peers in the network,
// but other has port definitions or has drop policy.
func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap, ) []*mgmProto.FirewallRule {
func (d *DefaultManager) squashAcceptRules(networkMap *mgmProto.NetworkMap) []*mgmProto.FirewallRule {
totalIPs := 0
for _, p := range append(networkMap.RemotePeers, networkMap.OfflinePeers...) {
for range p.AllowedIps {

View File

@@ -25,6 +25,7 @@ import (
"github.com/netbirdio/netbird/client/internal/peer"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/stdnet"
nbnet "github.com/netbirdio/netbird/client/net"
cProto "github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
@@ -34,7 +35,6 @@ import (
relayClient "github.com/netbirdio/netbird/shared/relay/client"
signal "github.com/netbirdio/netbird/shared/signal/client"
"github.com/netbirdio/netbird/util"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/version"
)
@@ -437,6 +437,7 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
EnableSSHSFTP: config.EnableSSHSFTP,
EnableSSHLocalPortForwarding: config.EnableSSHLocalPortForwarding,
EnableSSHRemotePortForwarding: config.EnableSSHRemotePortForwarding,
DisableSSHAuth: config.DisableSSHAuth,
DNSRouteInterval: config.DNSRouteInterval,
DisableClientRoutes: config.DisableClientRoutes,
@@ -527,6 +528,7 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -49,6 +49,7 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
cProto "github.com/netbirdio/netbird/client/proto"
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
"github.com/netbirdio/netbird/shared/management/domain"
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
@@ -117,6 +118,7 @@ type EngineConfig struct {
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DNSRouteInterval time.Duration
@@ -264,6 +266,7 @@ func NewEngine(
path = mobileDep.StateFilePath
}
engine.stateManager = statemanager.New(path)
engine.stateManager.RegisterState(&sshconfig.ShutdownState{})
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
return engine
@@ -676,14 +679,10 @@ func (e *Engine) removeAllPeers() error {
return nil
}
// removePeer closes an existing peer connection, removes a peer, and clears authorized key of the SSH server
// removePeer closes an existing peer connection and removes a peer
func (e *Engine) removePeer(peerKey string) error {
log.Debugf("removing peer from engine %s", peerKey)
if e.sshServer != nil {
e.sshServer.RemoveAuthorizedKey(peerKey)
}
e.connMgr.RemovePeerConn(peerKey)
err := e.statusRecorder.RemovePeer(peerKey)
@@ -859,6 +858,7 @@ func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
if err := e.mgmClient.SyncMeta(info); err != nil {
@@ -920,6 +920,7 @@ func (e *Engine) receiveManagementEvents() {
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
err = e.mgmClient.Sync(e.ctx, info, e.handleSync)
@@ -1074,24 +1075,10 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
e.statusRecorder.FinishPeerListModifications()
// update SSHServer by adding remote peer SSH keys
if e.sshServer != nil {
for _, config := range networkMap.GetRemotePeers() {
if config.GetSshConfig() != nil && config.GetSshConfig().GetSshPubKey() != nil {
err := e.sshServer.AddAuthorizedKey(config.WgPubKey, string(config.GetSshConfig().GetSshPubKey()))
if err != nil {
log.Warnf("failed adding authorized key to SSH DefaultServer %v", err)
}
}
}
}
// update peer SSH host keys in status recorder for daemon API access
e.updatePeerSSHHostKeys(networkMap.GetRemotePeers())
// update SSH client known_hosts with peer host keys for OpenSSH client
if err := e.updateSSHKnownHosts(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH known_hosts: %v", err)
if err := e.updateSSHClientConfig(networkMap.GetRemotePeers()); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
}
}
@@ -1480,6 +1467,7 @@ func (e *Engine) readInitialSettings() ([]*route.Route, *nbdns.Config, bool, err
e.config.EnableSSHSFTP,
e.config.EnableSSHLocalPortForwarding,
e.config.EnableSSHRemotePortForwarding,
e.config.DisableSSHAuth,
)
netMap, err := e.mgmClient.GetNetworkMap(info)

View File

@@ -5,10 +5,8 @@ import (
"errors"
"fmt"
"net/netip"
"runtime"
"strings"
"github.com/gliderlabs/ssh"
log "github.com/sirupsen/logrus"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
@@ -21,9 +19,6 @@ import (
type sshServer interface {
Start(ctx context.Context, addr netip.AddrPort) error
Stop() error
RemoveAuthorizedKey(peer string)
AddAuthorizedKey(peer, newKey string) error
SetSocketFilter(ifIdx int)
}
func (e *Engine) setupSSHPortRedirection() error {
@@ -44,22 +39,6 @@ func (e *Engine) setupSSHPortRedirection() error {
return nil
}
func (e *Engine) setupSSHSocketFilter(server sshServer) error {
if runtime.GOOS != "linux" {
return nil
}
netInterface := e.wgInterface.ToInterface()
if netInterface == nil {
return errors.New("failed to get WireGuard network interface")
}
server.SetSocketFilter(netInterface.Index)
log.Debugf("SSH socket filter configured for interface %s (index: %d)", netInterface.Name, netInterface.Index)
return nil
}
func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
if e.config.BlockInbound {
log.Info("SSH server is disabled because inbound connections are blocked")
@@ -83,66 +62,76 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
return nil
}
return e.startSSHServer()
if e.config.DisableSSHAuth != nil && *e.config.DisableSSHAuth {
log.Info("starting SSH server without JWT authentication (authentication disabled by config)")
return e.startSSHServer(nil)
}
if protoJWT := sshConf.GetJwtConfig(); protoJWT != nil {
jwtConfig := &sshserver.JWTConfig{
Issuer: protoJWT.GetIssuer(),
Audience: protoJWT.GetAudience(),
KeysLocation: protoJWT.GetKeysLocation(),
MaxTokenAge: protoJWT.GetMaxTokenAge(),
}
return e.startSSHServer(jwtConfig)
}
return errors.New("SSH server requires valid JWT configuration")
}
// updateSSHKnownHosts updates the SSH known_hosts file with peer host keys for OpenSSH client
func (e *Engine) updateSSHKnownHosts(remotePeers []*mgmProto.RemotePeerConfig) error {
peerKeys := e.extractPeerHostKeys(remotePeers)
if len(peerKeys) == 0 {
log.Debug("no SSH-enabled peers found, skipping known_hosts update")
// updateSSHClientConfig updates the SSH client configuration with peer information
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
peerInfo := e.extractPeerSSHInfo(remotePeers)
if len(peerInfo) == 0 {
log.Debug("no SSH-enabled peers found, skipping SSH config update")
return nil
}
if err := e.updateKnownHostsFile(peerKeys); err != nil {
return err
configMgr := sshconfig.New()
if err := configMgr.SetupSSHClientConfig(peerInfo); err != nil {
log.Warnf("failed to update SSH client config: %v", err)
return nil // Don't fail engine startup on SSH config issues
}
log.Debugf("updated SSH client config with %d peers", len(peerInfo))
if err := e.stateManager.UpdateState(&sshconfig.ShutdownState{
SSHConfigDir: configMgr.GetSSHConfigDir(),
SSHConfigFile: configMgr.GetSSHConfigFile(),
}); err != nil {
log.Warnf("failed to update SSH config state: %v", err)
}
e.updateSSHClientConfig(peerKeys)
log.Debugf("updated SSH known_hosts with %d peer host keys", len(peerKeys))
return nil
}
// extractPeerHostKeys extracts SSH host keys from peer configurations
func (e *Engine) extractPeerHostKeys(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerHostKey {
var peerKeys []sshconfig.PeerHostKey
// extractPeerSSHInfo extracts SSH information from peer configurations
func (e *Engine) extractPeerSSHInfo(remotePeers []*mgmProto.RemotePeerConfig) []sshconfig.PeerSSHInfo {
var peerInfo []sshconfig.PeerSSHInfo
for _, peerConfig := range remotePeers {
peerHostKey, ok := e.parsePeerHostKey(peerConfig)
if ok {
peerKeys = append(peerKeys, peerHostKey)
if peerConfig.GetSshConfig() == nil {
continue
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
continue
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
peerInfo = append(peerInfo, sshconfig.PeerSSHInfo{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
})
}
return peerKeys
}
// parsePeerHostKey parses a single peer's SSH host key configuration
func (e *Engine) parsePeerHostKey(peerConfig *mgmProto.RemotePeerConfig) (sshconfig.PeerHostKey, bool) {
if peerConfig.GetSshConfig() == nil {
return sshconfig.PeerHostKey{}, false
}
sshPubKeyBytes := peerConfig.GetSshConfig().GetSshPubKey()
if len(sshPubKeyBytes) == 0 {
return sshconfig.PeerHostKey{}, false
}
hostKey, _, _, _, err := ssh.ParseAuthorizedKey(sshPubKeyBytes)
if err != nil {
log.Warnf("failed to parse SSH public key for peer %s: %v", peerConfig.GetWgPubKey(), err)
return sshconfig.PeerHostKey{}, false
}
peerIP := e.extractPeerIP(peerConfig)
hostname := e.extractHostname(peerConfig)
return sshconfig.PeerHostKey{
Hostname: hostname,
IP: peerIP,
FQDN: peerConfig.GetFqdn(),
HostKey: hostKey,
}, true
return peerInfo
}
// extractPeerIP extracts IP address from peer's allowed IPs
@@ -171,25 +160,6 @@ func (e *Engine) extractHostname(peerConfig *mgmProto.RemotePeerConfig) string {
return ""
}
// updateKnownHostsFile updates the SSH known_hosts file
func (e *Engine) updateKnownHostsFile(peerKeys []sshconfig.PeerHostKey) error {
configMgr := sshconfig.NewManager()
if err := configMgr.UpdatePeerHostKeys(peerKeys); err != nil {
return fmt.Errorf("update peer host keys: %w", err)
}
return nil
}
// updateSSHClientConfig updates SSH client configuration with peer hostnames
func (e *Engine) updateSSHClientConfig(peerKeys []sshconfig.PeerHostKey) {
configMgr := sshconfig.NewManager()
if err := configMgr.SetupSSHClientConfig(peerKeys); err != nil {
log.Warnf("failed to update SSH client config with peer hostnames: %v", err)
} else {
log.Debugf("updated SSH client config with %d peer hostnames", len(peerKeys))
}
}
// updatePeerSSHHostKeys updates peer SSH host keys in the status recorder for daemon API access
func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig) {
for _, peerConfig := range remotePeers {
@@ -210,30 +180,51 @@ func (e *Engine) updatePeerSSHHostKeys(remotePeers []*mgmProto.RemotePeerConfig)
log.Debugf("updated peer SSH host keys for daemon API access")
}
// GetPeerSSHKey returns the SSH host key for a specific peer by IP or FQDN
func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
e.syncMsgMux.Lock()
statusRecorder := e.statusRecorder
e.syncMsgMux.Unlock()
if statusRecorder == nil {
return nil, false
}
fullStatus := statusRecorder.GetFullStatus()
for _, peerState := range fullStatus.Peers {
if peerState.IP == peerAddress || peerState.FQDN == peerAddress {
if len(peerState.SSHHostKey) > 0 {
return peerState.SSHHostKey, true
}
return nil, false
}
}
return nil, false
}
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
func (e *Engine) cleanupSSHConfig() {
configMgr := sshconfig.NewManager()
configMgr := sshconfig.New()
if err := configMgr.RemoveSSHClientConfig(); err != nil {
log.Warnf("failed to remove SSH client config: %v", err)
} else {
log.Debugf("SSH client config cleanup completed")
}
if err := configMgr.RemoveKnownHostsFile(); err != nil {
log.Warnf("failed to remove SSH known_hosts: %v", err)
} else {
log.Debugf("SSH known_hosts cleanup completed")
}
}
// startSSHServer initializes and starts the SSH server with proper configuration.
func (e *Engine) startSSHServer() error {
func (e *Engine) startSSHServer(jwtConfig *sshserver.JWTConfig) error {
if e.wgInterface == nil {
return errors.New("wg interface not initialized")
}
server := sshserver.New(e.config.SSHKey)
serverConfig := &sshserver.Config{
HostKeyPEM: e.config.SSHKey,
JWT: jwtConfig,
}
server := sshserver.New(serverConfig)
wgAddr := e.wgInterface.Address()
server.SetNetworkValidation(wgAddr)
@@ -259,15 +250,10 @@ func (e *Engine) startSSHServer() error {
log.Warnf("failed to setup SSH port redirection: %v", err)
}
if err := e.setupSSHSocketFilter(server); err != nil {
return fmt.Errorf("set socket filter: %w", err)
}
if err := server.Start(e.ctx, listenAddr); err != nil {
return fmt.Errorf("start SSH server: %w", err)
}
return nil
}

View File

@@ -281,7 +281,15 @@ func TestEngine_SSH(t *testing.T) {
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}

View File

@@ -128,6 +128,7 @@ func doMgmLogin(ctx context.Context, mgmClient *mgm.GrpcClient, pubSSHKey []byte
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := mgmClient.Login(*serverKey, sysInfo, pubSSHKey, config.DNSLabels)
return serverKey, loginResp, err
@@ -158,6 +159,7 @@ func registerPeer(ctx context.Context, serverPublicKey wgtypes.Key, client *mgm.
config.EnableSSHSFTP,
config.EnableSSHLocalPortForwarding,
config.EnableSSHRemotePortForwarding,
config.DisableSSHAuth,
)
loginResp, err := client.Register(serverPublicKey, validSetupKey.String(), jwtToken, info, pubSSHKey, config.DNSLabels)
if err != nil {

View File

@@ -21,9 +21,9 @@ import (
"github.com/netbirdio/netbird/client/internal/ingressgw"
"github.com/netbirdio/netbird/client/internal/relay"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
relayClient "github.com/netbirdio/netbird/shared/relay/client"
"github.com/netbirdio/netbird/route"
)
const eventQueueSize = 10

View File

@@ -54,6 +54,7 @@ type ConfigInput struct {
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
NATExternalIPs []string
CustomDNSAddress []byte
RosenpassEnabled *bool
@@ -102,6 +103,7 @@ type Config struct {
EnableSSHSFTP *bool
EnableSSHLocalPortForwarding *bool
EnableSSHRemotePortForwarding *bool
DisableSSHAuth *bool
DisableClientRoutes bool
DisableServerRoutes bool
@@ -423,6 +425,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
updated = true
}
if input.DisableSSHAuth != nil && input.DisableSSHAuth != config.DisableSSHAuth {
if *input.DisableSSHAuth {
log.Infof("disabling SSH authentication")
} else {
log.Infof("enabling SSH authentication")
}
config.DisableSSHAuth = input.DisableSSHAuth
updated = true
}
if input.DNSRouteInterval != nil && *input.DNSRouteInterval != config.DNSRouteInterval {
log.Infof("updating DNS route interval to %s (old value %s)",
input.DNSRouteInterval.String(), config.DNSRouteInterval.String())

View File

@@ -18,8 +18,8 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/iface"
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/shared/management/domain"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (