Reduce complexity

This commit is contained in:
Viktor Liu
2025-07-02 20:43:17 +02:00
parent 4bbca28eb6
commit 96084e3a02
6 changed files with 321 additions and 228 deletions

View File

@@ -370,7 +370,23 @@ func createHostKeyCallbackWithDaemonAddr(addr, daemonAddr string) (ssh.HostKeyCa
// verifyHostKeyViaDaemon verifies SSH host key by querying the NetBird daemon
func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
// Connect to NetBird daemon using the same logic as CLI
client, err := connectToDaemon(daemonAddr)
if err != nil {
return err
}
defer func() {
if err := client.Close(); err != nil {
log.Debugf("daemon connection close error: %v", err)
}
}()
addresses := buildAddressList(hostname, remote)
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
return verifyKeyWithDaemon(client, addresses, key)
}
func connectToDaemon(daemonAddr string) (*grpc.ClientConn, error) {
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
defer cancel()
@@ -382,64 +398,70 @@ func verifyHostKeyViaDaemon(hostname string, remote net.Addr, key ssh.PublicKey,
)
if err != nil {
log.Debugf("failed to connect to NetBird daemon at %s: %v", daemonAddr, err)
return fmt.Errorf("failed to connect to NetBird daemon: %w", err)
return nil, fmt.Errorf("failed to connect to NetBird daemon: %w", err)
}
defer func() {
if err := conn.Close(); err != nil {
log.Debugf("daemon connection close error: %v", err)
}
}()
return conn, nil
}
client := proto.NewDaemonServiceClient(conn)
// Try both hostname and IP address from remote.String()
func buildAddressList(hostname string, remote net.Addr) []string {
addresses := []string{hostname}
if host, _, err := net.SplitHostPort(remote.String()); err == nil {
if host != hostname {
addresses = append(addresses, host)
}
}
return addresses
}
log.Debugf("verifying SSH host key for hostname=%s, remote=%s, addresses=%v", hostname, remote.String(), addresses)
func verifyKeyWithDaemon(conn *grpc.ClientConn, addresses []string, key ssh.PublicKey) error {
client := proto.NewDaemonServiceClient(conn)
for _, addr := range addresses {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: addr,
})
cancel()
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
if err != nil {
log.Debugf("daemon query error for %s: %v", addr, err)
continue
}
if !response.GetFound() {
log.Debugf("SSH host key not found in daemon for address: %s", addr)
continue
}
// Parse the stored SSH host key
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(response.GetSshHostKey())
if err != nil {
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
continue
}
// Compare the keys
if key.Type() == storedKey.Type() && string(key.Marshal()) == string(storedKey.Marshal()) {
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
if err := checkAddressKey(client, addr, key); err == nil {
return nil
} else {
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), key.Type())
}
}
return fmt.Errorf("SSH host key not found or does not match in NetBird daemon")
}
func checkAddressKey(client proto.DaemonServiceClient, addr string, key ssh.PublicKey) error {
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
response, err := client.GetPeerSSHHostKey(ctx, &proto.GetPeerSSHHostKeyRequest{
PeerAddress: addr,
})
log.Debugf("daemon query for address %s: found=%v, error=%v", addr, response != nil && response.GetFound(), err)
if err != nil {
log.Debugf("daemon query error for %s: %v", addr, err)
return err
}
if !response.GetFound() {
log.Debugf("SSH host key not found in daemon for address: %s", addr)
return fmt.Errorf("key not found")
}
return compareKeys(response.GetSshHostKey(), key, addr)
}
func compareKeys(storedKeyData []byte, presentedKey ssh.PublicKey, addr string) error {
storedKey, _, _, _, err := ssh.ParseAuthorizedKey(storedKeyData)
if err != nil {
log.Debugf("failed to parse stored SSH host key for %s: %v", addr, err)
return err
}
if presentedKey.Type() == storedKey.Type() && string(presentedKey.Marshal()) == string(storedKey.Marshal()) {
log.Debugf("SSH host key verified via NetBird daemon for %s", addr)
return nil
}
log.Debugf("SSH host key mismatch for %s: stored type=%s, presented type=%s", addr, storedKey.Type(), presentedKey.Type())
return fmt.Errorf("key mismatch")
}
// getKnownHostsFiles returns paths to known_hosts files in order of preference
func getKnownHostsFiles() []string {
var files []string
@@ -469,41 +491,49 @@ func getKnownHostsFiles() []string {
// createHostKeyCallbackWithOptions creates a host key verification callback with custom options
func createHostKeyCallbackWithOptions(addr string, opts DialOptions) (ssh.HostKeyCallback, error) {
return func(hostname string, remote net.Addr, key ssh.PublicKey) error {
// First try to get host key from NetBird daemon (if daemon address provided)
if opts.DaemonAddr != "" {
if err := verifyHostKeyViaDaemon(hostname, remote, key, opts.DaemonAddr); err == nil {
return nil
}
if err := tryDaemonVerification(hostname, remote, key, opts.DaemonAddr); err == nil {
return nil
}
// Fallback to known_hosts files
var knownHostsFiles []string
if opts.KnownHostsFile != "" {
knownHostsFiles = append(knownHostsFiles, opts.KnownHostsFile)
} else {
knownHostsFiles = getKnownHostsFiles()
}
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
// Try each known_hosts callback
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
return tryKnownHostsVerification(hostname, remote, key, opts.KnownHostsFile)
}, nil
}
func tryDaemonVerification(hostname string, remote net.Addr, key ssh.PublicKey, daemonAddr string) error {
if daemonAddr == "" {
return fmt.Errorf("no daemon address")
}
return verifyHostKeyViaDaemon(hostname, remote, key, daemonAddr)
}
func tryKnownHostsVerification(hostname string, remote net.Addr, key ssh.PublicKey, knownHostsFile string) error {
knownHostsFiles := getKnownHostsFilesList(knownHostsFile)
hostKeyCallbacks := buildHostKeyCallbacks(knownHostsFiles)
for _, callback := range hostKeyCallbacks {
if err := callback(hostname, remote, key); err == nil {
return nil
}
}
return fmt.Errorf("host key verification failed: key not found in NetBird daemon or any known_hosts file")
}
func getKnownHostsFilesList(knownHostsFile string) []string {
if knownHostsFile != "" {
return []string{knownHostsFile}
}
return getKnownHostsFiles()
}
func buildHostKeyCallbacks(knownHostsFiles []string) []ssh.HostKeyCallback {
var hostKeyCallbacks []ssh.HostKeyCallback
for _, file := range knownHostsFiles {
if callback, err := knownhosts.New(file); err == nil {
hostKeyCallbacks = append(hostKeyCallbacks, callback)
}
}
return hostKeyCallbacks
}
// createSSHKeyAuth creates SSH key authentication from a private key file
func createSSHKeyAuth(keyFile string) (ssh.AuthMethod, error) {
keyData, err := os.ReadFile(keyFile)