mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-26 20:26:39 +00:00
Reduce complexity
This commit is contained in:
@@ -370,7 +370,23 @@ func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCa
|
||||
|
||||
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
|
||||
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
// Connect to NetBird daemon using the same logic as CLI
|
||||
client, err := connectToDaemon(daemonAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defer func() {
|
||||
if err := client.Close(); err != nil {
|
||||
log.Debugf("daemon connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
addresses := buildAddressList(hostname, remote)
|
||||
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
|
||||
|
||||
return verifyKeyWithDaemon(client, addresses, key)
|
||||
}
|
||||
|
||||
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||
defer cancel()
|
||||
|
||||
@@ -382,64 +398,70 @@ func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey,
|
||||
)
|
||||
if err != nil {
|
||||
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err)
|
||||
return fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("daemon connection close error: %v", err)
|
||||
}
|
||||
}()
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
// Try both hostname and IP address from remote.String()
|
||||
func buildAddressList(hostname string, remote net.Addr) []string {
|
||||
addresses := []string{hostname}
|
||||
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
|
||||
if host != hostname {
|
||||
addresses = append(addresses, host)
|
||||
}
|
||||
}
|
||||
return addresses
|
||||
}
|
||||
|
||||
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
|
||||
func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error {
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
for _, addr := range addresses {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: addr,
|
||||
})
|
||||
cancel()
|
||||
|
||||
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("daemon query error for %s: %v", addr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
log.Debugf("SSH host key not found in daemon for address: %s", addr)
|
||||
continue
|
||||
}
|
||||
|
||||
// Parse the stored SSH host key
|
||||
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(response.GetSshHostKey())
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
|
||||
continue
|
||||
}
|
||||
|
||||
// Compare the keys
|
||||
if key.Type() == storedKey.Type() && string(key.Marshal()) == string(storedKey.Marshal()) {
|
||||
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
|
||||
if err := checkAddressKey(client, addr, key); err == nil {
|
||||
return nil
|
||||
} else {
|
||||
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), key.Type())
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
|
||||
}
|
||||
|
||||
func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
|
||||
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
|
||||
PeerAddress: addr,
|
||||
})
|
||||
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
|
||||
|
||||
if err != nil {
|
||||
log.Debugf("daemon query error for %s: %v", addr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if !response.GetFound() {
|
||||
log.Debugf("SSH host key not found in daemon for address: %s", addr)
|
||||
return fmt.Errorf("key not found")
|
||||
}
|
||||
|
||||
return compareKeys(response.GetSshHostKey(), key, addr)
|
||||
}
|
||||
|
||||
func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error {
|
||||
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
|
||||
if err != nil {
|
||||
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
|
||||
return err
|
||||
}
|
||||
|
||||
if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) {
|
||||
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type())
|
||||
return fmt.Errorf("key mismatch")
|
||||
}
|
||||
|
||||
// getKnownHostsFiles returns paths to known_hosts files in order of preference
|
||||
func getKnownHostsFiles() []string {
|
||||
var files []string
|
||||
@@ -469,41 +491,49 @@ func getKnownHostsFiles() []string {
|
||||
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options
|
||||
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) {
|
||||
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
|
||||
// First try to get host key from NetBird daemon (if daemon address provided)
|
||||
if opts.DaemonAddr != "" {
|
||||
if err := verifyHostKeyViaDaemon(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||
return nil
|
||||
}
|
||||
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Fallback to known_hosts files
|
||||
var knownHostsFiles []string
|
||||
|
||||
if opts.KnownHostsFile != "" {
|
||||
knownHostsFiles = append(knownHostsFiles, opts.KnownHostsFile)
|
||||
} else {
|
||||
knownHostsFiles = getKnownHostsFiles()
|
||||
}
|
||||
|
||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
||||
|
||||
for _, file := range knownHostsFiles {
|
||||
if callback, err := knownhosts.New(file); err == nil {
|
||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
||||
}
|
||||
}
|
||||
|
||||
// Try each known_hosts callback
|
||||
for _, callback := range hostKeyCallbacks {
|
||||
if err := callback(hostname, remote, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
|
||||
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
|
||||
}, nil
|
||||
}
|
||||
|
||||
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
|
||||
if daemonAddr == "" {
|
||||
return fmt.Errorf("no daemon address")
|
||||
}
|
||||
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
|
||||
}
|
||||
|
||||
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
|
||||
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
|
||||
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
|
||||
|
||||
for _, callback := range hostKeyCallbacks {
|
||||
if err := callback(hostname, remote, key); err == nil {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
|
||||
}
|
||||
|
||||
func getKnownHostsFilesList(knownHostsFile string) []string {
|
||||
if knownHostsFile != "" {
|
||||
return []string{knownHostsFile}
|
||||
}
|
||||
return getKnownHostsFiles()
|
||||
}
|
||||
|
||||
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
|
||||
var hostKeyCallbacks []ssh.HostKeyCallback
|
||||
for _, file := range knownHostsFiles {
|
||||
if callback, err := knownhosts.New(file); err == nil {
|
||||
hostKeyCallbacks = append(hostKeyCallbacks, callback)
|
||||
}
|
||||
}
|
||||
return hostKeyCallbacks
|
||||
}
|
||||
|
||||
// createSSHKeyAuth creates SSH key authentication from a private key file
|
||||
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
|
||||
keyData, err := os.ReadFile(keyFile)
|
||||
|
||||
Reference in New Issue
Block a user