[client,management] Rewrite the SSH feature (#4015)

This commit is contained in:
Viktor Liu
2025-11-17 17:10:41 +01:00
committed by GitHub
parent 0d79301141
commit d71a82769c
170 changed files with 18744 additions and 2853 deletions

View File

@@ -14,7 +14,6 @@ import (
"github.com/golang/mock/gomock"
"github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/stdnet"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
@@ -25,7 +24,10 @@ import (
"google.golang.org/grpc"
"google.golang.org/grpc/keepalive"
"github.com/netbirdio/netbird/client/internal/stdnet"
"github.com/netbirdio/management-integrations/integrations"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
nbgrpc "github.com/netbirdio/netbird/management/internals/shared/grpc"
@@ -46,7 +48,7 @@ import (
icemaker "github.com/netbirdio/netbird/client/internal/peer/ice"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/routemanager"
"github.com/netbirdio/netbird/client/ssh"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/system"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
@@ -214,11 +216,13 @@ func TestMain(m *testing.M) {
}
func TestEngine_SSH(t *testing.T) {
if runtime.GOOS == "windows" {
t.Skip("skipping TestEngine_SSH")
key, err := wgtypes.GeneratePrivateKey()
if err != nil {
t.Fatal(err)
return
}
key, err := wgtypes.GeneratePrivateKey()
sshKey, err := nbssh.GeneratePrivateKey(nbssh.ED25519)
if err != nil {
t.Fatal(err)
return
@@ -240,6 +244,7 @@ func TestEngine_SSH(t *testing.T) {
WgPort: 33100,
ServerSSHAllowed: true,
MTU: iface.DefaultMTU,
SSHKey: sshKey,
},
MobileDependency{},
peer.NewRecorder("https://mgm"),
@@ -250,35 +255,8 @@ func TestEngine_SSH(t *testing.T) {
UpdateDNSServerFunc: func(serial uint64, update nbdns.Config) error { return nil },
}
var sshKeysAdded []string
var sshPeersRemoved []string
sshCtx, cancel := context.WithCancel(context.Background())
engine.sshServerFunc = func(hostKeyPEM []byte, addr string) (ssh.Server, error) {
return &ssh.MockServer{
Ctx: sshCtx,
StopFunc: func() error {
cancel()
return nil
},
StartFunc: func() error {
<-ctx.Done()
return ctx.Err()
},
AddAuthorizedKeyFunc: func(peer, newKey string) error {
sshKeysAdded = append(sshKeysAdded, newKey)
return nil
},
RemoveAuthorizedKeyFunc: func(peer string) {
sshPeersRemoved = append(sshPeersRemoved, peer)
},
}, nil
}
err = engine.Start(nil, nil)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
defer func() {
err := engine.Stop()
@@ -304,9 +282,7 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
@@ -314,19 +290,24 @@ func TestEngine_SSH(t *testing.T) {
networkMap = &mgmtProto.NetworkMap{
Serial: 7,
PeerConfig: &mgmtProto.PeerConfig{Address: "100.64.0.1/24",
SshConfig: &mgmtProto.SSHConfig{SshEnabled: true}},
SshConfig: &mgmtProto.SSHConfig{
SshEnabled: true,
JwtConfig: &mgmtProto.JWTConfig{
Issuer: "test-issuer",
Audience: "test-audience",
KeysLocation: "test-keys",
MaxTokenAge: 3600,
},
}},
RemotePeers: []*mgmtProto.RemotePeerConfig{peerWithSSH},
RemotePeersIsEmpty: false,
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshKeysAdded, "ssh-ed25519 AAAAC3NzaC1lZDI1NTE5AAAAIFATYCqaQw/9id1Qkq3n16JYhDhXraI6Pc1fgB8ynEfQ")
// now remove peer
networkMap = &mgmtProto.NetworkMap{
@@ -336,13 +317,10 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
// time.Sleep(250 * time.Millisecond)
assert.NotNil(t, engine.sshServer)
assert.Contains(t, sshPeersRemoved, "MNHf3Ma6z6mdLbriAJbqhX7+nM/B71lgw2+91q3LfhU=")
// now disable SSH server
networkMap = &mgmtProto.NetworkMap{
@@ -354,12 +332,70 @@ func TestEngine_SSH(t *testing.T) {
}
err = engine.updateNetworkMap(networkMap)
if err != nil {
t.Fatal(err)
}
require.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHUpdateLogic(t *testing.T) {
// Test that SSH server start/stop logic works based on config
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: false, // Start with SSH disabled
},
syncMsgMux: &sync.Mutex{},
}
// Test SSH disabled config
sshConfig := &mgmtProto.SSHConfig{SshEnabled: false}
err := engine.updateSSH(sshConfig)
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
// Test inbound blocked
engine.config.BlockInbound = true
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
engine.config.BlockInbound = false
// Test with server SSH not allowed
err = engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
}
func TestEngine_SSHServerConsistency(t *testing.T) {
t.Run("server set only on successful creation", func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: true,
SSHKey: []byte("test-key"),
},
syncMsgMux: &sync.Mutex{},
}
engine.wgInterface = nil
err := engine.updateSSH(&mgmtProto.SSHConfig{SshEnabled: true})
assert.Error(t, err)
assert.Nil(t, engine.sshServer)
})
t.Run("cleanup handles nil gracefully", func(t *testing.T) {
engine := &Engine{
config: &EngineConfig{
ServerSSHAllowed: false,
},
syncMsgMux: &sync.Mutex{},
}
err := engine.stopSSHServer()
assert.NoError(t, err)
assert.Nil(t, engine.sshServer)
})
}
func TestEngine_UpdateNetworkMap(t *testing.T) {
@@ -1589,7 +1625,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock())
accountManager, err := server.BuildManager(context.Background(), store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}