mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client,management] Rewrite the SSH feature (#4015)
This commit is contained in:
171
client/ssh/common.go
Normal file
171
client/ssh/common.go
Normal file
@@ -0,0 +1,171 @@
|
||||
package ssh
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
const (
|
||||
NetBirdSSHConfigFile = "99-netbird.conf"
|
||||
|
||||
UnixSSHConfigDir = "/etc/ssh/ssh_config.d"
|
||||
WindowsSSHConfigDir = "ssh/ssh_config.d"
|
||||
)
|
||||
|
||||
var (
|
||||
// ErrPeerNotFound indicates the peer was not found in the network
|
||||
ErrPeerNotFound = errors.New("peer not found in network")
|
||||
// ErrNoStoredKey indicates the peer has no stored SSH host key
|
||||
ErrNoStoredKey = errors.New("peer has no stored SSH host key")
|
||||
)
|
||||
|
||||
// HostKeyVerifier provides SSH host key verification
|
||||
type HostKeyVerifier interface {
|
||||
VerifySSHHostKey(peerAddress string, key []byte) error
|
||||
}
|
||||
|
||||
// DaemonHostKeyVerifier implements HostKeyVerifier using the NetBird daemon
|
||||
type DaemonHostKeyVerifier struct {
|
||||
client proto.DaemonServiceClient
|
||||
}
|
||||
|
||||
// NewDaemonHostKeyVerifier creates a new daemon-based host key verifier
|
||||
func NewDaemonHostKeyVerifier(client proto.DaemonServiceClient) *DaemonHostKeyVerifier {
|
||||
return &DaemonHostKeyVerifier{
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
|
||||
// VerifySSHHostKey verifies an SSH host key by querying the NetBird daemon
|
||||
func (d *DaemonHostKeyVerifier) VerifySSHHostKey(peerAddress string, presentedKey []byte) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
response, err := d.client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: peerAddress,
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
return ErrPeerNotFound
|
||||
}
|
||||
|
||||
storedKeyData := response.GetSshHostKey()
|
||||
|
||||
return VerifyHostKey(storedKeyData, presentedKey, peerAddress)
|
||||
}
|
||||
|
||||
// RequestJWTToken requests or retrieves a JWT token for SSH authentication
|
||||
func RequestJWTToken(ctx context.Context, client proto.DaemonServiceClient, stdout, stderr io.Writer, useCache bool, hint string) (string, error) {
|
||||
req := &proto.RequestJWTAuthRequest{}
|
||||
if hint != "" {
|
||||
req.Hint = &hint
|
||||
}
|
||||
authResponse, err := client.RequestJWTAuth(ctx, req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request JWT auth: %w", err)
|
||||
}
|
||||
|
||||
if useCache && authResponse.CachedToken != "" {
|
||||
log.Debug("Using cached authentication token")
|
||||
return authResponse.CachedToken, nil
|
||||
}
|
||||
|
||||
if stderr != nil {
|
||||
_, _ = fmt.Fprintln(stderr, "SSH authentication required.")
|
||||
_, _ = fmt.Fprintf(stderr, "Please visit: %s\n", authResponse.VerificationURIComplete)
|
||||
if authResponse.UserCode != "" {
|
||||
_, _ = fmt.Fprintf(stderr, "Or visit: %s and enter code: %s\n", authResponse.VerificationURI, authResponse.UserCode)
|
||||
}
|
||||
_, _ = fmt.Fprintln(stderr, "Waiting for authentication...")
|
||||
}
|
||||
|
||||
tokenResponse, err := client.WaitJWTToken(ctx, &proto.WaitJWTTokenRequest{
|
||||
DeviceCode: authResponse.DeviceCode,
|
||||
UserCode: authResponse.UserCode,
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for JWT token: %w", err)
|
||||
}
|
||||
|
||||
if stdout != nil {
|
||||
_, _ = fmt.Fprintln(stdout, "Authentication successful!")
|
||||
}
|
||||
return tokenResponse.Token, nil
|
||||
}
|
||||
|
||||
// VerifyHostKey verifies an SSH host key against stored peer key data.
|
||||
// Returns nil only if the presented key matches the stored key.
|
||||
// Returns ErrNoStoredKey if storedKeyData is empty.
|
||||
// Returns an error if the keys don't match or if parsing fails.
|
||||
func VerifyHostKey(storedKeyData []byte, presentedKey []byte, peerAddress string) error {
|
||||
if len(storedKeyData) == 0 {
|
||||
return ErrNoStoredKey
|
||||
}
|
||||
|
||||
storedPubKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse stored SSH key for %s: %w", peerAddress, err)
|
||||
}
|
||||
|
||||
if !bytes.Equal(presentedKey, storedPubKey.Marshal()) {
|
||||
return fmt.Errorf("SSH host key mismatch for %s", peerAddress)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// AddJWTAuth prepends JWT password authentication to existing auth methods.
|
||||
// This ensures JWT auth is tried first while preserving any existing auth methods.
|
||||
func AddJWTAuth(config *ssh.ClientConfig, jwtToken string) *ssh.ClientConfig {
|
||||
configWithJWT := *config
|
||||
configWithJWT.Auth = append([]ssh.AuthMethod{ssh.Password(jwtToken)}, config.Auth...)
|
||||
return &configWithJWT
|
||||
}
|
||||
|
||||
// CreateHostKeyCallback creates an SSH host key verification callback using the provided verifier.
|
||||
// It tries multiple addresses (hostname, IP) for the peer before failing.
|
||||
func CreateHostKeyCallback(verifier HostKeyVerifier) ssh.HostKeyCallback {
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
addresses := buildAddressList(hostname, remote)
|
||||
presentedKey := key.Marshal()
|
||||
|
||||
for _, addr := range addresses {
|
||||
if err := verifier.VerifySSHHostKey(addr, presentedKey); err != nil {
|
||||
if errors.Is(err, ErrPeerNotFound) {
|
||||
// Try other addresses for this peer
|
||||
continue
|
||||
}
|
||||
return err
|
||||
}
|
||||
// Verified
|
||||
return nil
|
||||
}
|
||||
|
||||
return fmt.Errorf("SSH host key verification failed: peer %s not found in network", hostname)
|
||||
}
|
||||
}
|
||||
|
||||
// buildAddressList creates a list of addresses to check for host key verification.
|
||||
// It includes the original hostname and extracts the host part from the remote address if different.
|
||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
addresses := []string{hostname}
|
||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
if host != hostname {
|
||||
addresses = append(addresses, host)
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
Reference in New Issue
Block a user