mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-16 07:16:38 +00:00
[client] Add non default socket file discovery (#5425)
- Automatic Unix daemon address discovery: if the default socket is missing, the client can find and use a single available socket. - Client startup now resolves daemon addresses more robustly while preserving non-Unix behavior.
This commit is contained in:
@@ -22,6 +22,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,6 +81,15 @@ var (
|
|||||||
Short: "",
|
Short: "",
|
||||||
Long: "",
|
Long: "",
|
||||||
SilenceUsage: true,
|
SilenceUsage: true,
|
||||||
|
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(cmd.Root())
|
||||||
|
|
||||||
|
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||||
|
if !isServiceCmd(cmd) {
|
||||||
|
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -386,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||||
@@ -399,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
|||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||||
|
func isServiceCmd(cmd *cobra.Command) bool {
|
||||||
|
for c := cmd; c != nil; c = c.Parent() {
|
||||||
|
if c.Name() == "service" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
60
client/internal/daemonaddr/resolve.go
Normal file
60
client/internal/daemonaddr/resolve.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var scanDir = "/var/run/netbird"
|
||||||
|
|
||||||
|
// setScanDir overrides the scan directory (used by tests).
|
||||||
|
func setScanDir(dir string) {
|
||||||
|
scanDir = dir
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||||
|
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||||
|
// mismatch between the netbird@.service template (which places the socket under
|
||||||
|
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
if !strings.HasPrefix(addr, "unix://") {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||||
|
if _, err := os.Stat(sockPath); err == nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(scanDir)
|
||||||
|
if err != nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var found []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(e.Name(), ".sock") {
|
||||||
|
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(found) {
|
||||||
|
case 1:
|
||||||
|
resolved := "unix://" + found[0]
|
||||||
|
log.Infof("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||||
|
return resolved
|
||||||
|
case 0:
|
||||||
|
return addr
|
||||||
|
default:
|
||||||
|
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
}
|
||||||
8
client/internal/daemonaddr/resolve_stub.go
Normal file
8
client/internal/daemonaddr/resolve_stub.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build windows || ios || android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
121
client/internal/daemonaddr/resolve_test.go
Normal file
121
client/internal/daemonaddr/resolve_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSockFile creates a regular file with a .sock extension.
|
||||||
|
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||||
|
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||||
|
func createSockFile(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
sock := filepath.Join(tmp, "netbird.sock")
|
||||||
|
createSockFile(t, sock)
|
||||||
|
|
||||||
|
addr := "unix://" + sock
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
// Default socket does not exist
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
// Create a scan dir with one socket
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
instanceSock := filepath.Join(sd, "main.sock")
|
||||||
|
createSockFile(t, instanceSock)
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
expected := "unix://" + instanceSock
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||||
|
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||||
|
addr := "tcp://127.0.0.1:41731"
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -268,7 +269,7 @@ func getDefaultDaemonAddr() string {
|
|||||||
if runtime.GOOS == "windows" {
|
if runtime.GOOS == "windows" {
|
||||||
return DefaultDaemonAddrWindows
|
return DefaultDaemonAddrWindows
|
||||||
}
|
}
|
||||||
return DefaultDaemonAddr
|
return daemonaddr.ResolveUnixDaemonAddr(DefaultDaemonAddr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DialOptions contains options for SSH connections
|
// DialOptions contains options for SSH connections
|
||||||
|
|||||||
Reference in New Issue
Block a user