mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
@@ -18,12 +18,16 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
)
|
||||
|
||||
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||
var ErrClientNotStarted = errors.New("client not started")
|
||||
var ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
var (
|
||||
ErrClientAlreadyStarted = errors.New("client already started")
|
||||
ErrClientNotStarted = errors.New("client not started")
|
||||
ErrEngineNotStarted = errors.New("engine not started")
|
||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||
)
|
||||
|
||||
// Client manages a netbird embedded client instance.
|
||||
type Client struct {
|
||||
@@ -238,17 +242,9 @@ func (c *Client) GetConfig() (profilemanager.Config, error) {
|
||||
// Dial dials a network address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, errors.New("engine not started")
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
nsnet, err := engine.GetNet()
|
||||
@@ -259,6 +255,11 @@ func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, e
|
||||
return nsnet.DialContext(ctx, network, address)
|
||||
}
|
||||
|
||||
// DialContext dials a network address in the netbird network with context
|
||||
func (c *Client) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
|
||||
return c.Dial(ctx, network, address)
|
||||
}
|
||||
|
||||
// ListenTCP listens on the given address in the netbird network.
|
||||
// Not applicable if the userspace networking mode is disabled.
|
||||
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||
@@ -314,18 +315,47 @@ func (c *Client) NewHTTPClient() *http.Client {
|
||||
}
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
// VerifySSHHostKey verifies an SSH host key against stored peer keys.
|
||||
// Returns nil if the key matches, ErrPeerNotFound if peer is not in network,
|
||||
// ErrNoStoredKey if peer has no stored key, or an error for verification failures.
|
||||
func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
storedKey, found := engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return sshcommon.ErrPeerNotFound
|
||||
}
|
||||
|
||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||
}
|
||||
|
||||
// getEngine safely retrieves the engine from the client with proper locking.
|
||||
// Returns ErrClientNotStarted if the client is not started.
|
||||
// Returns ErrEngineNotStarted if the engine is not available.
|
||||
func (c *Client) getEngine() (*internal.Engine, error) {
|
||||
c.mu.Lock()
|
||||
connect := c.connect
|
||||
if connect == nil {
|
||||
c.mu.Unlock()
|
||||
return nil, netip.Addr{}, errors.New("client not started")
|
||||
}
|
||||
c.mu.Unlock()
|
||||
|
||||
if connect == nil {
|
||||
return nil, ErrClientNotStarted
|
||||
}
|
||||
|
||||
engine := connect.Engine()
|
||||
if engine == nil {
|
||||
return nil, netip.Addr{}, errors.New("engine not started")
|
||||
return nil, ErrEngineNotStarted
|
||||
}
|
||||
|
||||
return engine, nil
|
||||
}
|
||||
|
||||
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||
engine, err := c.getEngine()
|
||||
if err != nil {
|
||||
return nil, netip.Addr{}, err
|
||||
}
|
||||
|
||||
addr, err := engine.Address()
|
||||
|
||||
Reference in New Issue
Block a user