mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-28 03:29:55 +00:00
Compare commits
1 Commits
feature/an
...
coderabbit
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
cb1eaf9e0d |
2
Makefile
2
Makefile
@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
|
||||
$(GOLANGCI_LINT):
|
||||
@echo "Installing golangci-lint..."
|
||||
@mkdir -p ./bin
|
||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
|
||||
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
|
||||
|
||||
# Lint only changed files (fast, for pre-push)
|
||||
lint: $(GOLANGCI_LINT)
|
||||
|
||||
@@ -8,7 +8,6 @@ import (
|
||||
"os"
|
||||
"slices"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
@@ -16,7 +15,6 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
"github.com/netbirdio/netbird/client/internal/listener"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
@@ -28,7 +26,6 @@ import (
|
||||
"github.com/netbirdio/netbird/formatter"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
types "github.com/netbirdio/netbird/upload-server/types"
|
||||
)
|
||||
|
||||
// ConnectionListener export internal Listener for mobile
|
||||
@@ -71,30 +68,7 @@ type Client struct {
|
||||
uiVersion string
|
||||
networkChangeListener listener.NetworkChangeListener
|
||||
|
||||
stateMu sync.RWMutex
|
||||
connectClient *internal.ConnectClient
|
||||
config *profilemanager.Config
|
||||
cacheDir string
|
||||
}
|
||||
|
||||
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
|
||||
c.stateMu.Lock()
|
||||
defer c.stateMu.Unlock()
|
||||
c.config = cfg
|
||||
c.cacheDir = cacheDir
|
||||
c.connectClient = cc
|
||||
}
|
||||
|
||||
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
|
||||
c.stateMu.RLock()
|
||||
defer c.stateMu.RUnlock()
|
||||
return c.config, c.cacheDir, c.connectClient
|
||||
}
|
||||
|
||||
func (c *Client) getConnectClient() *internal.ConnectClient {
|
||||
c.stateMu.RLock()
|
||||
defer c.stateMu.RUnlock()
|
||||
return c.connectClient
|
||||
}
|
||||
|
||||
// NewClient instantiate a new Client
|
||||
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
cacheDir := platformFiles.CacheDir()
|
||||
|
||||
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
cfgFile := platformFiles.ConfigurationFilePath()
|
||||
stateFile := platformFiles.StateFilePath()
|
||||
cacheDir := platformFiles.CacheDir()
|
||||
|
||||
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
|
||||
|
||||
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
c.setState(cfg, cacheDir, connectClient)
|
||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
// Stop the internal client and free the resources
|
||||
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
|
||||
}
|
||||
|
||||
func (c *Client) RenewTun(fd int) error {
|
||||
cc := c.getConnectClient()
|
||||
if cc == nil {
|
||||
if c.connectClient == nil {
|
||||
return fmt.Errorf("engine not running")
|
||||
}
|
||||
|
||||
e := cc.Engine()
|
||||
e := c.connectClient.Engine()
|
||||
if e == nil {
|
||||
return fmt.Errorf("engine not initialized")
|
||||
}
|
||||
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
|
||||
return e.RenewTun(fd)
|
||||
}
|
||||
|
||||
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
|
||||
// It works both with and without a running engine.
|
||||
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
|
||||
cfg, cacheDir, cc := c.stateSnapshot()
|
||||
|
||||
// If the engine hasn't been started, load config from disk
|
||||
if cfg == nil {
|
||||
var err error
|
||||
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
|
||||
ConfigPath: platformFiles.ConfigurationFilePath(),
|
||||
})
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("load config: %w", err)
|
||||
}
|
||||
cacheDir = platformFiles.CacheDir()
|
||||
}
|
||||
|
||||
deps := debug.GeneratorDependencies{
|
||||
InternalConfig: cfg,
|
||||
StatusRecorder: c.recorder,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
|
||||
if cc != nil {
|
||||
resp, err := cc.GetLatestSyncResponse()
|
||||
if err != nil {
|
||||
log.Warnf("get latest sync response: %v", err)
|
||||
}
|
||||
deps.SyncResponse = resp
|
||||
|
||||
if e := cc.Engine(); e != nil {
|
||||
if cm := e.GetClientMetrics(); cm != nil {
|
||||
deps.ClientMetrics = cm
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
bundleGenerator := debug.NewBundleGenerator(
|
||||
deps,
|
||||
debug.BundleConfig{
|
||||
Anonymize: anonymize,
|
||||
IncludeSystemInfo: true,
|
||||
},
|
||||
)
|
||||
|
||||
path, err := bundleGenerator.Generate()
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("generate debug bundle: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := os.Remove(path); err != nil {
|
||||
log.Errorf("failed to remove debug bundle file: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("upload debug bundle: %w", err)
|
||||
}
|
||||
|
||||
log.Infof("debug bundle uploaded with key %s", key)
|
||||
return key, nil
|
||||
}
|
||||
|
||||
// SetTraceLogLevel configure the logger to trace level
|
||||
func (c *Client) SetTraceLogLevel() {
|
||||
log.SetLevel(log.TraceLevel)
|
||||
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
|
||||
}
|
||||
|
||||
func (c *Client) Networks() *NetworkArray {
|
||||
cc := c.getConnectClient()
|
||||
if cc == nil {
|
||||
if c.connectClient == nil {
|
||||
log.Error("not connected")
|
||||
return nil
|
||||
}
|
||||
|
||||
engine := cc.Engine()
|
||||
engine := c.connectClient.Engine()
|
||||
if engine == nil {
|
||||
log.Error("could not get engine")
|
||||
return nil
|
||||
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
|
||||
}
|
||||
|
||||
func (c *Client) getRouteManager() (routemanager.Manager, error) {
|
||||
client := c.getConnectClient()
|
||||
client := c.connectClient
|
||||
if client == nil {
|
||||
return nil, fmt.Errorf("not connected")
|
||||
}
|
||||
|
||||
@@ -7,5 +7,4 @@ package android
|
||||
type PlatformFiles interface {
|
||||
ConfigurationFilePath() string
|
||||
StateFilePath() string
|
||||
CacheDir() string
|
||||
}
|
||||
|
||||
@@ -1,434 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package android
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"net"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
gossh "golang.org/x/crypto/ssh"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
nbssh "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/ssh/detection"
|
||||
)
|
||||
|
||||
const (
|
||||
sshDialTimeout = 30 * time.Second
|
||||
sshDetectionTimeout = 5 * time.Second
|
||||
)
|
||||
|
||||
// SSHTerminalListener receives SSH session events. It is implemented in Java.
|
||||
//
|
||||
// All callbacks are invoked from goroutines and may run concurrently with each
|
||||
// other; the implementation must be safe to call from any thread.
|
||||
type SSHTerminalListener interface {
|
||||
OnConnected()
|
||||
OnData(data []byte)
|
||||
OnClose(reason string)
|
||||
OnError(message string)
|
||||
}
|
||||
|
||||
// SSHClient is a NetBird-aware SSH client exposed to Java via gomobile.
|
||||
//
|
||||
// It dials through the running NetBird tunnel and runs a standard SSH session
|
||||
// on top with PTY enabled. Host-key verification uses the NetBird-provided
|
||||
// peer SSH host keys, identical to the desktop client.
|
||||
type SSHClient struct {
|
||||
nb *Client
|
||||
mu sync.Mutex
|
||||
listener SSHTerminalListener
|
||||
urlOpener URLOpener
|
||||
|
||||
sshClient *gossh.Client
|
||||
session *gossh.Session
|
||||
stdin io.WriteCloser
|
||||
closed bool
|
||||
}
|
||||
|
||||
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
|
||||
func NewSSHClient(c *Client) *SSHClient {
|
||||
return &SSHClient{nb: c}
|
||||
}
|
||||
|
||||
// SetListener registers the Java listener. Must be called before Connect to
|
||||
// receive any events.
|
||||
func (s *SSHClient) SetListener(l SSHTerminalListener) {
|
||||
s.mu.Lock()
|
||||
s.listener = l
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// SetURLOpener registers the Java URL opener used to display the device-code
|
||||
// authorization page in a Custom Tabs window when the target peer requires
|
||||
// JWT authentication. Must be set before Connect to be effective.
|
||||
func (s *SSHClient) SetURLOpener(opener URLOpener) {
|
||||
s.mu.Lock()
|
||||
s.urlOpener = opener
|
||||
s.mu.Unlock()
|
||||
}
|
||||
|
||||
// Connect dials the SSH server through the NetBird tunnel and performs the
|
||||
// SSH handshake. It auto-detects the server type via SSH banner inspection
|
||||
// and selects the appropriate authentication path:
|
||||
//
|
||||
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
|
||||
// flow, opens the verification URL through the registered URLOpener, and
|
||||
// uses the resulting token as the SSH password. Host-key verification
|
||||
// uses the NetBird peer registry.
|
||||
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
|
||||
// private key. Host-key verification uses the NetBird peer registry.
|
||||
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
|
||||
// first (so a user-installed NetBird public key works), then falls back
|
||||
// to the supplied password if non-empty. Host-key verification is
|
||||
// disabled (TOFU pending).
|
||||
//
|
||||
// The password parameter is only consulted for regular SSH servers.
|
||||
func (s *SSHClient) Connect(host string, port int, user, password string) error {
|
||||
cfg, _, cc := s.nb.stateSnapshot()
|
||||
if cc == nil {
|
||||
return errors.New("netbird client not running")
|
||||
}
|
||||
if cfg == nil {
|
||||
return errors.New("netbird config not loaded")
|
||||
}
|
||||
engine := cc.Engine()
|
||||
if engine == nil {
|
||||
return errors.New("netbird engine not available")
|
||||
}
|
||||
|
||||
serverType := detectServerType(host, port)
|
||||
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
|
||||
|
||||
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
clientConfig := &gossh.ClientConfig{
|
||||
User: user,
|
||||
Auth: authMethods,
|
||||
HostKeyCallback: hostKeyCallback,
|
||||
Timeout: sshDialTimeout,
|
||||
}
|
||||
return s.dialAndHandshake(host, port, clientConfig)
|
||||
}
|
||||
|
||||
// StartSession requests a PTY and starts an interactive shell. Output from
|
||||
// the session is forwarded to the listener via OnData.
|
||||
func (s *SSHClient) StartSession(cols, rows int) error {
|
||||
log.Debugf("SSH: starting session %dx%d", cols, rows)
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
s.mu.Unlock()
|
||||
|
||||
if sshClient == nil {
|
||||
return errors.New("ssh client not connected")
|
||||
}
|
||||
|
||||
session, err := sshClient.NewSession()
|
||||
if err != nil {
|
||||
return fmt.Errorf("new session: %w", err)
|
||||
}
|
||||
|
||||
modes := gossh.TerminalModes{
|
||||
gossh.ECHO: 1,
|
||||
gossh.TTY_OP_ISPEED: 14400,
|
||||
gossh.TTY_OP_OSPEED: 14400,
|
||||
gossh.VINTR: 3,
|
||||
gossh.VQUIT: 28,
|
||||
gossh.VERASE: 127,
|
||||
}
|
||||
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
|
||||
closeQuiet(session, "session after pty error")
|
||||
return fmt.Errorf("request pty: %w", err)
|
||||
}
|
||||
|
||||
stdin, err := session.StdinPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdin error")
|
||||
return fmt.Errorf("stdin pipe: %w", err)
|
||||
}
|
||||
stdout, err := session.StdoutPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stdout error")
|
||||
return fmt.Errorf("stdout pipe: %w", err)
|
||||
}
|
||||
stderr, err := session.StderrPipe()
|
||||
if err != nil {
|
||||
closeQuiet(session, "session after stderr error")
|
||||
return fmt.Errorf("stderr pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := session.Shell(); err != nil {
|
||||
closeQuiet(session, "session after shell error")
|
||||
return fmt.Errorf("start shell: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.session = session
|
||||
s.stdin = stdin
|
||||
s.mu.Unlock()
|
||||
|
||||
go s.readLoop(stdout, "stdout")
|
||||
go s.readLoop(stderr, "stderr")
|
||||
log.Debug("SSH: session started, shell running")
|
||||
return nil
|
||||
}
|
||||
|
||||
// Write sends data to the SSH session stdin.
|
||||
func (s *SSHClient) Write(data []byte) error {
|
||||
s.mu.Lock()
|
||||
stdin := s.stdin
|
||||
s.mu.Unlock()
|
||||
if stdin == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
if _, err := stdin.Write(data); err != nil {
|
||||
return fmt.Errorf("write stdin: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Resize updates the PTY window size.
|
||||
func (s *SSHClient) Resize(cols, rows int) error {
|
||||
s.mu.Lock()
|
||||
session := s.session
|
||||
s.mu.Unlock()
|
||||
if session == nil {
|
||||
return errors.New("ssh session not started")
|
||||
}
|
||||
return session.WindowChange(rows, cols)
|
||||
}
|
||||
|
||||
// Close terminates the SSH session and underlying connection. Safe to call
|
||||
// multiple times.
|
||||
func (s *SSHClient) Close() error {
|
||||
s.mu.Lock()
|
||||
sshClient := s.sshClient
|
||||
session := s.session
|
||||
stdin := s.stdin
|
||||
s.sshClient = nil
|
||||
s.session = nil
|
||||
s.stdin = nil
|
||||
s.mu.Unlock()
|
||||
|
||||
if stdin != nil {
|
||||
if err := stdin.Close(); err != nil {
|
||||
log.Debugf("ssh: stdin close: %v", err)
|
||||
}
|
||||
}
|
||||
if session != nil {
|
||||
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: session close: %v", err)
|
||||
}
|
||||
}
|
||||
var firstErr error
|
||||
if sshClient != nil {
|
||||
if err := sshClient.Close(); err != nil {
|
||||
firstErr = err
|
||||
}
|
||||
}
|
||||
s.notifyClose("closed by client")
|
||||
return firstErr
|
||||
}
|
||||
|
||||
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
|
||||
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
|
||||
|
||||
switch serverType {
|
||||
case detection.ServerTypeNetBirdJWT:
|
||||
token, err := s.requestJWTToken(cfg)
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("jwt: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.Password(token)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
case detection.ServerTypeNetBirdNoJWT:
|
||||
if cfg.SSHKey == "" {
|
||||
return nil, nil, errors.New("no NetBird SSH key available")
|
||||
}
|
||||
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
|
||||
if err != nil {
|
||||
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
|
||||
}
|
||||
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
|
||||
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
|
||||
|
||||
default: // regular SSH
|
||||
var auths []gossh.AuthMethod
|
||||
if cfg.SSHKey != "" {
|
||||
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
|
||||
auths = append(auths, gossh.PublicKeys(signer))
|
||||
} else {
|
||||
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
|
||||
}
|
||||
}
|
||||
if password != "" {
|
||||
pw := password
|
||||
auths = append(auths, gossh.Password(pw))
|
||||
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
|
||||
answers := make([]string, len(questions))
|
||||
for i := range questions {
|
||||
answers[i] = pw
|
||||
}
|
||||
return answers, nil
|
||||
}))
|
||||
}
|
||||
if len(auths) == 0 {
|
||||
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
|
||||
}
|
||||
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
|
||||
s.mu.Lock()
|
||||
urlOpener := s.urlOpener
|
||||
s.mu.Unlock()
|
||||
if urlOpener == nil {
|
||||
return "", errors.New("URL opener not configured for JWT auth")
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
|
||||
defer cancel()
|
||||
|
||||
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create oauth flow: %w", err)
|
||||
}
|
||||
|
||||
flowInfo, err := flow.RequestAuthInfo(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("request auth info: %w", err)
|
||||
}
|
||||
|
||||
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
||||
|
||||
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("wait for token: %w", err)
|
||||
}
|
||||
|
||||
token := tokenInfo.GetTokenToUse()
|
||||
if token == "" {
|
||||
return "", errors.New("empty token returned by IdP")
|
||||
}
|
||||
return token, nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
|
||||
addr := net.JoinHostPort(host, strconv.Itoa(port))
|
||||
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
|
||||
defer cancel()
|
||||
|
||||
var dialer net.Dialer
|
||||
conn, err := dialer.DialContext(ctx, "tcp", addr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("dial %s: %w", addr, err)
|
||||
}
|
||||
|
||||
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
|
||||
if err != nil {
|
||||
if cerr := conn.Close(); cerr != nil {
|
||||
log.Debugf("ssh: close after handshake error: %v", cerr)
|
||||
}
|
||||
return fmt.Errorf("ssh handshake: %w", err)
|
||||
}
|
||||
|
||||
s.mu.Lock()
|
||||
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
|
||||
log.Infof("SSH: connected to %s", addr)
|
||||
if listener != nil {
|
||||
listener.OnConnected()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SSHClient) readLoop(r io.Reader, name string) {
|
||||
buf := make([]byte, 4096)
|
||||
for {
|
||||
n, err := r.Read(buf)
|
||||
if n > 0 {
|
||||
s.mu.Lock()
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
chunk := make([]byte, n)
|
||||
copy(chunk, buf[:n])
|
||||
listener.OnData(chunk)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
if !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh %s read: %v", name, err)
|
||||
}
|
||||
s.notifyClose(err.Error())
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SSHClient) notifyClose(reason string) {
|
||||
s.mu.Lock()
|
||||
if s.closed {
|
||||
s.mu.Unlock()
|
||||
return
|
||||
}
|
||||
s.closed = true
|
||||
listener := s.listener
|
||||
s.mu.Unlock()
|
||||
if listener != nil {
|
||||
listener.OnClose(reason)
|
||||
}
|
||||
}
|
||||
|
||||
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
|
||||
type engineHostKeyVerifier struct {
|
||||
engine *internal.Engine
|
||||
}
|
||||
|
||||
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
|
||||
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
|
||||
if !found {
|
||||
return nbssh.ErrPeerNotFound
|
||||
}
|
||||
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
|
||||
}
|
||||
|
||||
func closeQuiet(c io.Closer, label string) {
|
||||
if c == nil {
|
||||
return
|
||||
}
|
||||
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
|
||||
log.Debugf("ssh: close %s: %v", label, err)
|
||||
}
|
||||
}
|
||||
|
||||
func detectServerType(host string, port int) detection.ServerType {
|
||||
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
|
||||
defer cancel()
|
||||
|
||||
dialer := &net.Dialer{}
|
||||
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
|
||||
if err != nil {
|
||||
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
|
||||
return detection.ServerTypeRegular
|
||||
}
|
||||
return serverType
|
||||
}
|
||||
@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||
// Close closes the tunnel interface
|
||||
func (w *WGIface) Close() error {
|
||||
w.mu.Lock()
|
||||
defer w.mu.Unlock()
|
||||
|
||||
var result *multierror.Error
|
||||
|
||||
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
|
||||
}
|
||||
|
||||
// Release w.mu before calling w.tun.Close(): the underlying
|
||||
// wireguard-go device.Close() waits for its send/receive goroutines
|
||||
// to drain. Some of those goroutines re-enter WGIface methods that
|
||||
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
|
||||
// holding the mutex here would deadlock the shutdown path.
|
||||
tun := w.tun
|
||||
w.mu.Unlock()
|
||||
|
||||
if err := tun.Close(); err != nil {
|
||||
if err := w.tun.Close(); err != nil {
|
||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||
}
|
||||
|
||||
|
||||
@@ -1,113 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package iface
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||
"golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||
)
|
||||
|
||||
// fakeTunDevice implements WGTunDevice and lets the test control when
|
||||
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
|
||||
// until its goroutines drain. Some of those goroutines (e.g. the packet
|
||||
// filter DNS hook in client/internal/dns) call back into WGIface, so if
|
||||
// WGIface.Close() held w.mu across tun.Close() the shutdown would
|
||||
// deadlock.
|
||||
type fakeTunDevice struct {
|
||||
closeStarted chan struct{}
|
||||
unblockClose chan struct{}
|
||||
}
|
||||
|
||||
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
|
||||
return nil, errors.New("not implemented")
|
||||
}
|
||||
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
|
||||
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
|
||||
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
|
||||
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
|
||||
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
|
||||
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
|
||||
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
|
||||
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
|
||||
|
||||
func (f *fakeTunDevice) Close() error {
|
||||
close(f.closeStarted)
|
||||
<-f.unblockClose
|
||||
return nil
|
||||
}
|
||||
|
||||
type fakeProxyFactory struct{}
|
||||
|
||||
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
|
||||
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
|
||||
func (fakeProxyFactory) Free() error { return nil }
|
||||
|
||||
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
|
||||
// that surfaces as a macOS test-timeout in
|
||||
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
|
||||
// while waiting for the wireguard-go device goroutines to finish, and
|
||||
// one of those goroutines (the DNS filter hook) calls back into
|
||||
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
|
||||
// the lock before tun.Close() returns control.
|
||||
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
|
||||
tun := &fakeTunDevice{
|
||||
closeStarted: make(chan struct{}),
|
||||
unblockClose: make(chan struct{}),
|
||||
}
|
||||
w := &WGIface{
|
||||
tun: tun,
|
||||
wgProxyFactory: fakeProxyFactory{},
|
||||
}
|
||||
|
||||
closeDone := make(chan error, 1)
|
||||
go func() {
|
||||
closeDone <- w.Close()
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-tun.closeStarted:
|
||||
case <-time.After(2 * time.Second):
|
||||
close(tun.unblockClose)
|
||||
t.Fatal("tun.Close() was never invoked")
|
||||
}
|
||||
|
||||
// Simulate the WireGuard read goroutine calling back into WGIface
|
||||
// via the packet filter's DNS hook. If Close() still held w.mu
|
||||
// during tun.Close(), this would block until the test timeout.
|
||||
getDeviceDone := make(chan struct{})
|
||||
var wg sync.WaitGroup
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
_ = w.GetDevice()
|
||||
close(getDeviceDone)
|
||||
}()
|
||||
|
||||
select {
|
||||
case <-getDeviceDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
close(tun.unblockClose)
|
||||
wg.Wait()
|
||||
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
|
||||
}
|
||||
|
||||
close(tun.unblockClose)
|
||||
select {
|
||||
case <-closeDone:
|
||||
case <-time.After(2 * time.Second):
|
||||
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
|
||||
}
|
||||
}
|
||||
@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
}
|
||||
|
||||
if u.address.Network.Contains(a) {
|
||||
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||
}
|
||||
|
||||
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
|
||||
u.addrCache.Store(addr.String(), isRouted)
|
||||
if isRouted {
|
||||
// Extra log, as the error only shows up with ICE logging enabled
|
||||
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -94,7 +94,6 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
dnsAddresses []netip.AddrPort,
|
||||
dnsReadyListener dns.ReadyListener,
|
||||
stateFilePath string,
|
||||
cacheDir string,
|
||||
) error {
|
||||
// in case of non Android os these variables will be nil
|
||||
mobileDependency := MobileDependency{
|
||||
@@ -104,7 +103,6 @@ func (c *ConnectClient) RunOnAndroid(
|
||||
HostDNSAddresses: dnsAddresses,
|
||||
DnsReadyListener: dnsReadyListener,
|
||||
StateFilePath: stateFilePath,
|
||||
TempDir: cacheDir,
|
||||
}
|
||||
return c.run(mobileDependency, nil, "")
|
||||
}
|
||||
@@ -340,7 +338,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
log.Error(err)
|
||||
return wrapErr(err)
|
||||
}
|
||||
engineConfig.TempDir = mobileDependency.TempDir
|
||||
|
||||
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
|
||||
c.statusRecorder.SetRelayMgr(relayManager)
|
||||
|
||||
@@ -16,6 +16,7 @@ import (
|
||||
"path/filepath"
|
||||
"runtime"
|
||||
"runtime/pprof"
|
||||
"slices"
|
||||
"sort"
|
||||
"strings"
|
||||
"time"
|
||||
@@ -30,6 +31,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
const readmeContent = `Netbird debug bundle
|
||||
@@ -232,7 +234,6 @@ type BundleGenerator struct {
|
||||
statusRecorder *peer.Status
|
||||
syncResponse *mgmProto.SyncResponse
|
||||
logPath string
|
||||
tempDir string
|
||||
cpuProfile []byte
|
||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
clientMetrics MetricsExporter
|
||||
@@ -255,7 +256,6 @@ type GeneratorDependencies struct {
|
||||
StatusRecorder *peer.Status
|
||||
SyncResponse *mgmProto.SyncResponse
|
||||
LogPath string
|
||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||
CPUProfile []byte
|
||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||
ClientMetrics MetricsExporter
|
||||
@@ -275,7 +275,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
statusRecorder: deps.StatusRecorder,
|
||||
syncResponse: deps.SyncResponse,
|
||||
logPath: deps.LogPath,
|
||||
tempDir: deps.TempDir,
|
||||
cpuProfile: deps.CPUProfile,
|
||||
refreshStatus: deps.RefreshStatus,
|
||||
clientMetrics: deps.ClientMetrics,
|
||||
@@ -288,7 +287,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
||||
|
||||
// Generate creates a debug bundle and returns the location.
|
||||
func (g *BundleGenerator) Generate() (resp string, err error) {
|
||||
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
|
||||
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("create zip file: %w", err)
|
||||
}
|
||||
@@ -374,8 +373,15 @@ func (g *BundleGenerator) createArchive() error {
|
||||
log.Errorf("failed to add wg show output: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addPlatformLog(); err != nil {
|
||||
log.Errorf("failed to add logs to debug bundle: %v", err)
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
log.Errorf("failed to add systemd logs as fallback: %v", err)
|
||||
}
|
||||
}
|
||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||
log.Errorf("failed to add systemd logs: %v", err)
|
||||
}
|
||||
|
||||
if err := g.addUpdateLogs(); err != nil {
|
||||
|
||||
@@ -1,41 +0,0 @@
|
||||
//go:build android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"os/exec"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
func (g *BundleGenerator) addPlatformLog() error {
|
||||
cmd := exec.Command("/system/bin/logcat", "-d")
|
||||
stdout, err := cmd.StdoutPipe()
|
||||
if err != nil {
|
||||
return fmt.Errorf("logcat stdout pipe: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Start(); err != nil {
|
||||
return fmt.Errorf("start logcat: %w", err)
|
||||
}
|
||||
|
||||
var logReader io.Reader = stdout
|
||||
if g.anonymize {
|
||||
var pw *io.PipeWriter
|
||||
logReader, pw = io.Pipe()
|
||||
go anonymizeLog(stdout, pw, g.anonymizer)
|
||||
}
|
||||
|
||||
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
|
||||
return fmt.Errorf("add logcat to zip: %w", err)
|
||||
}
|
||||
|
||||
if err := cmd.Wait(); err != nil {
|
||||
return fmt.Errorf("wait logcat: %w", err)
|
||||
}
|
||||
|
||||
log.Debug("added logcat output to debug bundle")
|
||||
return nil
|
||||
}
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build !android
|
||||
|
||||
package debug
|
||||
|
||||
import (
|
||||
"slices"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
func (g *BundleGenerator) addPlatformLog() error {
|
||||
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
|
||||
if err := g.addLogfile(); err != nil {
|
||||
log.Errorf("failed to add log file to debug bundle: %v", err)
|
||||
if err := g.trySystemdLogFallback(); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if err := g.trySystemdLogFallback(); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@@ -140,7 +140,6 @@ type EngineConfig struct {
|
||||
ProfileConfig *profilemanager.Config
|
||||
|
||||
LogPath string
|
||||
TempDir string
|
||||
}
|
||||
|
||||
// EngineServices holds the external service dependencies required by the Engine.
|
||||
@@ -1096,7 +1095,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
||||
StatusRecorder: e.statusRecorder,
|
||||
SyncResponse: syncResponse,
|
||||
LogPath: e.config.LogPath,
|
||||
TempDir: e.config.TempDir,
|
||||
ClientMetrics: e.clientMetrics,
|
||||
RefreshStatus: func() {
|
||||
e.RunHealthProbes(true)
|
||||
|
||||
@@ -22,8 +22,4 @@ type MobileDependency struct {
|
||||
DnsManager dns.IosDnsManager
|
||||
FileDescriptor int32
|
||||
StateFilePath string
|
||||
|
||||
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
|
||||
// On Android, this should be set to the app's cache directory.
|
||||
TempDir string
|
||||
}
|
||||
|
||||
@@ -1,10 +0,0 @@
|
||||
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
|
||||
|
||||
package systemops
|
||||
|
||||
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
|
||||
// always fall through to the ref-counter exclusion-route path; these stubs
|
||||
// exist only so systemops_unix.go compiles.
|
||||
func (r *SysOps) setupAdvancedRouting() error { return nil }
|
||||
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
|
||||
func (r *SysOps) flushPlatformExtras() error { return nil }
|
||||
@@ -1,241 +0,0 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os"
|
||||
"time"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/net/route"
|
||||
"golang.org/x/sys/unix"
|
||||
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// scopedRouteBudget bounds retries for the scoped default route. Installing or
|
||||
// deleting it matters enough that we're willing to spend longer waiting for the
|
||||
// kernel reply than for per-prefix exclusion routes.
|
||||
const scopedRouteBudget = 5 * time.Second
|
||||
|
||||
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
|
||||
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
|
||||
// resolve gateway'd destinations while the VPN's split default owns the
|
||||
// unscoped table.
|
||||
//
|
||||
// Timing note: this runs during routeManager.Init, which happens before the
|
||||
// VPN interface is created and before any peer routes propagate. The initial
|
||||
// mgmt / signal / relay TCP dials always fire before this runs, so those
|
||||
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
|
||||
// lookup, which at that point correctly picks the physical default. Those
|
||||
// already-established TCP flows keep their originally-selected interface for
|
||||
// their lifetime on Darwin because the kernel caches the egress route
|
||||
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
|
||||
// afterwards does not migrate them since the original en0 default stays in
|
||||
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
|
||||
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
|
||||
func (r *SysOps) setupAdvancedRouting() error {
|
||||
// Drop any previously-cached egress interface before reinstalling. On a
|
||||
// refresh, a family that no longer resolves would otherwise keep the stale
|
||||
// binding, causing new sockets to scope to an interface without a matching
|
||||
// scoped default.
|
||||
nbnet.ClearBoundInterfaces()
|
||||
|
||||
if err := r.flushScopedDefaults(); err != nil {
|
||||
log.Warnf("flush residual scoped defaults: %v", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
installed := 0
|
||||
|
||||
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
|
||||
ok, err := r.installScopedDefaultFor(unspec)
|
||||
if err != nil {
|
||||
merr = multierror.Append(merr, err)
|
||||
continue
|
||||
}
|
||||
if ok {
|
||||
installed++
|
||||
}
|
||||
}
|
||||
|
||||
if installed == 0 && merr != nil {
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
if merr != nil {
|
||||
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// installScopedDefaultFor resolves the physical default nexthop for the given
|
||||
// address family, installs a scoped default via it, and caches the iface for
|
||||
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
|
||||
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
||||
nexthop, err := GetNextHop(unspec)
|
||||
if err != nil {
|
||||
if errors.Is(err, vars.ErrRouteNotFound) {
|
||||
return false, nil
|
||||
}
|
||||
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
|
||||
}
|
||||
if nexthop.Intf == nil {
|
||||
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
||||
}
|
||||
|
||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
||||
}
|
||||
|
||||
af := unix.AF_INET
|
||||
if unspec.Is6() {
|
||||
af = unix.AF_INET6
|
||||
}
|
||||
nbnet.SetBoundInterface(af, nexthop.Intf)
|
||||
via := "point-to-point"
|
||||
if nexthop.IP.IsValid() {
|
||||
via = nexthop.IP.String()
|
||||
}
|
||||
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) cleanupAdvancedRouting() error {
|
||||
nbnet.ClearBoundInterfaces()
|
||||
return r.flushScopedDefaults()
|
||||
}
|
||||
|
||||
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
|
||||
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
|
||||
// removed on the next boot regardless of whether a profile is brought up.
|
||||
func (r *SysOps) flushPlatformExtras() error {
|
||||
return r.flushScopedDefaults()
|
||||
}
|
||||
|
||||
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
|
||||
// Safe to call at startup to clear residual entries from a prior session.
|
||||
func (r *SysOps) flushScopedDefaults() error {
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return fmt.Errorf("fetch routing table: %w", err)
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return fmt.Errorf("parse routing table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
removed := 0
|
||||
|
||||
for _, msg := range msgs {
|
||||
rtMsg, ok := msg.(*route.RouteMessage)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if rtMsg.Flags&routeProtoFlag == 0 {
|
||||
continue
|
||||
}
|
||||
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
info, err := MsgToRoute(rtMsg)
|
||||
if err != nil {
|
||||
log.Debugf("skip scoped flush: %v", err)
|
||||
continue
|
||||
}
|
||||
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
if err := r.deleteScopedRoute(rtMsg); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
|
||||
info.Dst, rtMsg.Index, err))
|
||||
continue
|
||||
}
|
||||
removed++
|
||||
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
|
||||
}
|
||||
|
||||
if removed > 0 {
|
||||
log.Infof("flushed %d residual scoped default route(s)", removed)
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
}
|
||||
|
||||
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
|
||||
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
|
||||
}
|
||||
|
||||
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
|
||||
// Preserve identifying flags from the stored route (including RTF_GATEWAY
|
||||
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
|
||||
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
|
||||
del := &route.RouteMessage{
|
||||
Type: unix.RTM_DELETE,
|
||||
Flags: rtMsg.Flags & keep,
|
||||
Version: unix.RTM_VERSION,
|
||||
Seq: r.getSeq(),
|
||||
Index: rtMsg.Index,
|
||||
Addrs: rtMsg.Addrs,
|
||||
}
|
||||
return r.writeRouteMessage(del, scopedRouteBudget)
|
||||
}
|
||||
|
||||
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
|
||||
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
|
||||
|
||||
msg := &route.RouteMessage{
|
||||
Type: action,
|
||||
Flags: flags,
|
||||
Version: unix.RTM_VERSION,
|
||||
ID: uintptr(os.Getpid()),
|
||||
Seq: r.getSeq(),
|
||||
Index: nexthop.Intf.Index,
|
||||
}
|
||||
|
||||
const numAddrs = unix.RTAX_NETMASK + 1
|
||||
addrs := make([]route.Addr, numAddrs)
|
||||
|
||||
dst, err := addrToRouteAddr(unspec)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build destination: %w", err)
|
||||
}
|
||||
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
|
||||
if err != nil {
|
||||
return fmt.Errorf("build netmask: %w", err)
|
||||
}
|
||||
addrs[unix.RTAX_DST] = dst
|
||||
addrs[unix.RTAX_NETMASK] = mask
|
||||
|
||||
if nexthop.IP.IsValid() {
|
||||
msg.Flags |= unix.RTF_GATEWAY
|
||||
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
|
||||
if err != nil {
|
||||
return fmt.Errorf("build gateway: %w", err)
|
||||
}
|
||||
addrs[unix.RTAX_GATEWAY] = gw
|
||||
} else {
|
||||
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
|
||||
Index: nexthop.Intf.Index,
|
||||
Name: nexthop.Intf.Name,
|
||||
}
|
||||
}
|
||||
msg.Addrs = addrs
|
||||
|
||||
return r.writeRouteMessage(msg, scopedRouteBudget)
|
||||
}
|
||||
|
||||
func afOf(a netip.Addr) string {
|
||||
if a.Is4() {
|
||||
return "IPv4"
|
||||
}
|
||||
return "IPv6"
|
||||
}
|
||||
125
client/internal/routemanager/systemops/systemops_darwin_test.go
Normal file
125
client/internal/routemanager/systemops/systemops_darwin_test.go
Normal file
@@ -0,0 +1,125 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// TestAfOf verifies that afOf returns the correct string for each address family.
|
||||
func TestAfOf(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
addr netip.Addr
|
||||
want string
|
||||
}{
|
||||
{
|
||||
name: "IPv4 unspecified",
|
||||
addr: netip.IPv4Unspecified(),
|
||||
want: "IPv4",
|
||||
},
|
||||
{
|
||||
name: "IPv4 private",
|
||||
addr: netip.MustParseAddr("10.0.0.1"),
|
||||
want: "IPv4",
|
||||
},
|
||||
{
|
||||
name: "IPv4 loopback",
|
||||
addr: netip.MustParseAddr("127.0.0.1"),
|
||||
want: "IPv4",
|
||||
},
|
||||
{
|
||||
name: "IPv6 unspecified",
|
||||
addr: netip.IPv6Unspecified(),
|
||||
want: "IPv6",
|
||||
},
|
||||
{
|
||||
name: "IPv6 loopback",
|
||||
addr: netip.MustParseAddr("::1"),
|
||||
want: "IPv6",
|
||||
},
|
||||
{
|
||||
name: "IPv6 unicast",
|
||||
addr: netip.MustParseAddr("2001:db8::1"),
|
||||
want: "IPv6",
|
||||
},
|
||||
{
|
||||
name: "IPv6 link-local",
|
||||
addr: netip.MustParseAddr("fe80::1"),
|
||||
want: "IPv6",
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, afOf(tt.addr))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAddrRouted_AdvancedRoutingBypassesTunnelLookup verifies that when
|
||||
// AdvancedRouting is active, IsAddrRouted immediately returns (false, zero)
|
||||
// regardless of the provided vpn routes, because the WG socket is bound to
|
||||
// the physical interface via IP_BOUND_IF and bypasses the main routing table.
|
||||
func TestIsAddrRouted_AdvancedRoutingBypassesTunnelLookup(t *testing.T) {
|
||||
// On darwin, AdvancedRouting returns true unless overridden.
|
||||
// Ensure we reset the state after the test.
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "false")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
nbnet.Init()
|
||||
|
||||
require.True(t, nbnet.AdvancedRouting(), "test requires advanced routing to be active on darwin")
|
||||
|
||||
vpnRoutes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.1.0/24"),
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addr netip.Addr
|
||||
}{
|
||||
{"IPv4 in VPN route", netip.MustParseAddr("10.0.0.1")},
|
||||
{"IPv4 in narrow VPN route", netip.MustParseAddr("192.168.1.100")},
|
||||
{"IPv4 default route covered", netip.MustParseAddr("8.8.8.8")},
|
||||
{"IPv6 in VPN route", netip.MustParseAddr("2001:db8::1")},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
routed, prefix := IsAddrRouted(tt.addr, vpnRoutes)
|
||||
assert.False(t, routed, "should not be marked as routed via VPN when advanced routing is active")
|
||||
assert.Equal(t, netip.Prefix{}, prefix, "matched prefix should be zero when advanced routing is active")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAddrRouted_LegacyModeFallsThroughToTable verifies that when
|
||||
// NB_USE_LEGACY_ROUTING=true disables advanced routing, IsAddrRouted
|
||||
// performs the normal VPN-route vs local-route comparison.
|
||||
func TestIsAddrRouted_LegacyModeFallsThroughToTable(t *testing.T) {
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
nbnet.Init()
|
||||
|
||||
require.False(t, nbnet.AdvancedRouting(), "test requires advanced routing to be disabled")
|
||||
|
||||
// Use an address that is very unlikely to exist in the host routing table
|
||||
// as a local route, so the VPN route wins.
|
||||
vpnRoutes := []netip.Prefix{
|
||||
netip.MustParsePrefix("198.51.100.0/24"), // TEST-NET-2 – not in normal routing tables
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddr("198.51.100.1")
|
||||
routed, _ := IsAddrRouted(addr, vpnRoutes)
|
||||
// We cannot assert a specific outcome because it depends on the host's
|
||||
// routing table, but we CAN assert that the call did not panic and returned
|
||||
// a consistent pair.
|
||||
_ = routed
|
||||
}
|
||||
@@ -21,7 +21,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/util"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/net/hooks"
|
||||
)
|
||||
|
||||
@@ -32,6 +31,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
|
||||
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
|
||||
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
|
||||
|
||||
var ErrRoutingIsSeparate = errors.New("routing is separate")
|
||||
|
||||
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
@@ -396,16 +397,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
|
||||
}
|
||||
|
||||
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
|
||||
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
|
||||
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
|
||||
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
|
||||
if nbnet.AdvancedRouting() {
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
localRoutes, err := GetRoutesFromTable()
|
||||
localRoutes, err := hasSeparateRouting()
|
||||
if err != nil {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
if !errors.Is(err, ErrRoutingIsSeparate) {
|
||||
log.Errorf("Failed to get routes: %v", err)
|
||||
}
|
||||
return false, netip.Prefix{}
|
||||
}
|
||||
|
||||
|
||||
@@ -0,0 +1,140 @@
|
||||
//go:build !android && !ios
|
||||
|
||||
package systemops
|
||||
|
||||
import (
|
||||
"net/netip"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
)
|
||||
|
||||
// withLegacyRouting forcibly puts the nbnet package into legacy (non-advanced)
|
||||
// routing mode for the duration of the test, then restores the previous value.
|
||||
func withLegacyRouting(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
nbnet.Init()
|
||||
t.Cleanup(func() {
|
||||
// After the test, re-initialise with no overrides so that subsequent
|
||||
// tests start with a clean state.
|
||||
nbnet.Init()
|
||||
})
|
||||
}
|
||||
|
||||
// withAdvancedRouting attempts to enable advanced routing for the test.
|
||||
// On platforms where Init() cannot produce AdvancedRouting()=true (e.g. Linux
|
||||
// without root), the calling test is skipped.
|
||||
func withAdvancedRouting(t *testing.T) {
|
||||
t.Helper()
|
||||
t.Setenv("NB_USE_LEGACY_ROUTING", "false")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
nbnet.Init()
|
||||
t.Cleanup(func() {
|
||||
nbnet.Init()
|
||||
})
|
||||
if !nbnet.AdvancedRouting() {
|
||||
t.Skip("advanced routing not available in this environment (need root or darwin)")
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAddrRouted_AdvancedRoutingShortCircuit verifies that when
|
||||
// AdvancedRouting() returns true, IsAddrRouted immediately returns
|
||||
// (false, zero prefix) regardless of any VPN routes, because the WG socket is
|
||||
// bound directly to the physical interface and bypasses the kernel routing table.
|
||||
func TestIsAddrRouted_AdvancedRoutingShortCircuit(t *testing.T) {
|
||||
withAdvancedRouting(t)
|
||||
|
||||
vpnRoutes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
netip.MustParsePrefix("192.168.0.0/16"),
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
netip.MustParsePrefix("::/0"),
|
||||
}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
addr string
|
||||
}{
|
||||
{"IPv4 matched by 10/8", "10.0.0.1"},
|
||||
{"IPv4 matched by 192.168/16", "192.168.1.1"},
|
||||
{"IPv4 caught by default", "8.8.8.8"},
|
||||
{"IPv6 caught by default", "2001:db8::1"},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
addr := netip.MustParseAddr(tt.addr)
|
||||
routed, prefix := IsAddrRouted(addr, vpnRoutes)
|
||||
assert.False(t, routed, "advanced routing must short-circuit VPN route check")
|
||||
assert.Equal(t, netip.Prefix{}, prefix, "returned prefix must be zero under advanced routing")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestIsAddrRouted_AdvancedRouting_EmptyRoutes verifies the short-circuit with
|
||||
// an empty vpnRoutes slice — the result must still be (false, zero).
|
||||
func TestIsAddrRouted_AdvancedRouting_EmptyRoutes(t *testing.T) {
|
||||
withAdvancedRouting(t)
|
||||
|
||||
routed, prefix := IsAddrRouted(netip.MustParseAddr("10.0.0.1"), nil)
|
||||
assert.False(t, routed)
|
||||
assert.Equal(t, netip.Prefix{}, prefix)
|
||||
}
|
||||
|
||||
// TestIsAddrRouted_LegacyMode_NonVPNAddress verifies that under legacy routing
|
||||
// an address that is not covered by any VPN prefix returns false.
|
||||
func TestIsAddrRouted_LegacyMode_NonVPNAddress(t *testing.T) {
|
||||
withLegacyRouting(t)
|
||||
|
||||
// 198.51.100.x (TEST-NET-2) is not normally in any routing table.
|
||||
vpnRoutes := []netip.Prefix{
|
||||
netip.MustParsePrefix("10.0.0.0/8"),
|
||||
}
|
||||
|
||||
addr := netip.MustParseAddr("198.51.100.1")
|
||||
routed, _ := IsAddrRouted(addr, vpnRoutes)
|
||||
assert.False(t, routed, "address not in VPN routes should not be marked as VPN-routed")
|
||||
}
|
||||
|
||||
// TestIsAddrRouted_LegacyMode_EmptyVPNRoutes verifies that an empty vpnRoutes
|
||||
// slice always yields (false, zero prefix) even under legacy routing.
|
||||
func TestIsAddrRouted_LegacyMode_EmptyVPNRoutes(t *testing.T) {
|
||||
withLegacyRouting(t)
|
||||
|
||||
routed, prefix := IsAddrRouted(netip.MustParseAddr("10.0.0.1"), nil)
|
||||
assert.False(t, routed)
|
||||
assert.Equal(t, netip.Prefix{}, prefix)
|
||||
}
|
||||
|
||||
// TestIsAddrRouted_AdvancedVsLegacy_ContrastiveBehaviour documents the
|
||||
// contract difference between the two modes: with a VPN default route and an
|
||||
// address that matches it, legacy mode may mark it as VPN-routed while advanced
|
||||
// mode must never do so.
|
||||
func TestIsAddrRouted_AdvancedVsLegacy_ContrastiveBehaviour(t *testing.T) {
|
||||
vpnRoutes := []netip.Prefix{
|
||||
netip.MustParsePrefix("0.0.0.0/0"),
|
||||
}
|
||||
addr := netip.MustParseAddr("8.8.8.8")
|
||||
|
||||
// --- advanced routing: must always return false ---
|
||||
t.Run("advanced routing", func(t *testing.T) {
|
||||
withAdvancedRouting(t)
|
||||
routed, prefix := IsAddrRouted(addr, vpnRoutes)
|
||||
assert.False(t, routed, "advanced routing must bypass VPN route lookup")
|
||||
assert.Equal(t, netip.Prefix{}, prefix)
|
||||
})
|
||||
|
||||
// --- legacy routing: delegates to kernel table check, does not panic ---
|
||||
t.Run("legacy routing", func(t *testing.T) {
|
||||
withLegacyRouting(t)
|
||||
// We don't assert true/false here because it depends on the host
|
||||
// routing table, but the call must not panic and must return
|
||||
// a valid (bool, prefix) pair.
|
||||
routed, prefix := IsAddrRouted(addr, vpnRoutes)
|
||||
t.Logf("legacy IsAddrRouted(%s, %v) = (%v, %v)", addr, vpnRoutes, routed, prefix)
|
||||
})
|
||||
}
|
||||
@@ -22,6 +22,10 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return []netip.Prefix{}, nil
|
||||
}
|
||||
|
||||
// GetDetailedRoutesFromTable returns empty routes for WASM.
|
||||
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
|
||||
return []DetailedRoute{}, nil
|
||||
|
||||
@@ -894,6 +894,13 @@ func getAddressFamily(prefix netip.Prefix) int {
|
||||
return netlink.FAMILY_V6
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
if !nbnet.AdvancedRouting() {
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
return nil, ErrRoutingIsSeparate
|
||||
}
|
||||
|
||||
func isOpErr(err error) bool {
|
||||
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
|
||||
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {
|
||||
|
||||
@@ -48,6 +48,10 @@ func EnableIPForwarding() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func hasSeparateRouting() ([]netip.Prefix, error) {
|
||||
return GetRoutesFromTable()
|
||||
}
|
||||
|
||||
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
|
||||
func GetIPRules() ([]IPRule, error) {
|
||||
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)
|
||||
|
||||
@@ -25,9 +25,6 @@ import (
|
||||
|
||||
const (
|
||||
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
|
||||
|
||||
// routeBudget bounds retries for per-prefix exclusion route programming.
|
||||
routeBudget = 1 * time.Second
|
||||
)
|
||||
|
||||
var routeProtoFlag int
|
||||
@@ -44,42 +41,26 @@ func init() {
|
||||
}
|
||||
|
||||
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return r.setupAdvancedRouting()
|
||||
}
|
||||
|
||||
log.Infof("Using legacy routing setup with ref counters")
|
||||
return r.setupRefCounter(initAddresses, stateManager)
|
||||
}
|
||||
|
||||
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
|
||||
if advancedRouting {
|
||||
return r.cleanupAdvancedRouting()
|
||||
}
|
||||
|
||||
return r.cleanupRefCounter(stateManager)
|
||||
}
|
||||
|
||||
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
|
||||
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
|
||||
// crashed prior session can't leave crud in the table.
|
||||
func (r *SysOps) FlushMarkedRoutes() error {
|
||||
var merr *multierror.Error
|
||||
|
||||
if err := r.flushPlatformExtras(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
|
||||
}
|
||||
|
||||
rib, err := retryFetchRIB()
|
||||
if err != nil {
|
||||
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
|
||||
return fmt.Errorf("fetch routing table: %w", err)
|
||||
}
|
||||
|
||||
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
|
||||
if err != nil {
|
||||
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
|
||||
return fmt.Errorf("parse routing table: %w", err)
|
||||
}
|
||||
|
||||
var merr *multierror.Error
|
||||
flushedCount := 0
|
||||
|
||||
for _, msg := range msgs {
|
||||
@@ -136,12 +117,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
||||
return fmt.Errorf("invalid prefix: %s", prefix)
|
||||
}
|
||||
|
||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||
if err != nil {
|
||||
return fmt.Errorf("build route message: %w", err)
|
||||
}
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = 1 * time.Second
|
||||
|
||||
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
|
||||
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
|
||||
a := "add"
|
||||
if action == unix.RTM_DELETE {
|
||||
a = "remove"
|
||||
@@ -151,91 +132,50 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
|
||||
return nil
|
||||
}
|
||||
|
||||
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
|
||||
// kernel's matching reply, retrying transient failures until budget elapses.
|
||||
// Callers do not need to manage sockets or seq numbers themselves.
|
||||
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
|
||||
expBackOff := backoff.NewExponentialBackOff()
|
||||
expBackOff.InitialInterval = 50 * time.Millisecond
|
||||
expBackOff.MaxInterval = 500 * time.Millisecond
|
||||
expBackOff.MaxElapsedTime = budget
|
||||
|
||||
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
|
||||
}
|
||||
|
||||
func routeMessageRoundtrip(msg *route.RouteMessage) error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
return fmt.Errorf("open routing socket: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("close routing socket: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
tv := unix.Timeval{Sec: 1}
|
||||
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
|
||||
}
|
||||
|
||||
// AF_ROUTE is a broadcast channel: every route socket on the host sees
|
||||
// every RTM_* event. With concurrent route programming the default
|
||||
// per-socket queue overflows and our own reply gets dropped.
|
||||
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
|
||||
log.Debugf("set SO_RCVBUF on route socket: %v", err)
|
||||
}
|
||||
|
||||
bytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
|
||||
}
|
||||
|
||||
if _, err = unix.Write(fd, bytes); err != nil {
|
||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||
}
|
||||
return readRouteResponse(fd, msg.Type, msg.Seq)
|
||||
}
|
||||
|
||||
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
|
||||
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
|
||||
// broadcast channel: interface up/down, third-party route changes and neighbor
|
||||
// discovery events can all land between our write and read, so we must filter.
|
||||
func readRouteResponse(fd, wantType, wantSeq int) error {
|
||||
pid := int32(os.Getpid())
|
||||
resp := make([]byte, 2048)
|
||||
deadline := time.Now().Add(time.Second)
|
||||
for {
|
||||
if time.Now().After(deadline) {
|
||||
// Transient: under concurrent pressure the kernel can drop our reply
|
||||
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
|
||||
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
|
||||
}
|
||||
n, err := unix.Read(fd, resp)
|
||||
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
|
||||
operation := func() error {
|
||||
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
|
||||
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
|
||||
continue
|
||||
return fmt.Errorf("open routing socket: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
|
||||
log.Warnf("failed to close routing socket: %v", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("read: %w", err))
|
||||
}()
|
||||
|
||||
msg, err := r.buildRouteMessage(action, prefix, nexthop)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
|
||||
}
|
||||
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||
continue
|
||||
|
||||
msgBytes, err := msg.Marshal()
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
|
||||
}
|
||||
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
|
||||
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
|
||||
// uniquely identifies our own reply among broadcast traffic.
|
||||
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
|
||||
continue
|
||||
|
||||
if _, err = unix.Write(fd, msgBytes); err != nil {
|
||||
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
|
||||
return fmt.Errorf("write: %w", err)
|
||||
}
|
||||
return backoff.Permanent(fmt.Errorf("write: %w", err))
|
||||
}
|
||||
if hdr.Errno != 0 {
|
||||
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
|
||||
|
||||
respBuf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, respBuf)
|
||||
if err != nil {
|
||||
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
|
||||
}
|
||||
|
||||
if n > 0 {
|
||||
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
|
||||
return backoff.Permanent(err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
return operation
|
||||
}
|
||||
|
||||
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
|
||||
@@ -243,7 +183,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
||||
Type: action,
|
||||
Flags: unix.RTF_UP | routeProtoFlag,
|
||||
Version: unix.RTM_VERSION,
|
||||
ID: uintptr(os.Getpid()),
|
||||
Seq: r.getSeq(),
|
||||
}
|
||||
|
||||
@@ -282,6 +221,19 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
|
||||
return msg, nil
|
||||
}
|
||||
|
||||
func (r *SysOps) parseRouteResponse(buf []byte) error {
|
||||
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
|
||||
return nil
|
||||
}
|
||||
|
||||
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
if rtMsg.Errno != 0 {
|
||||
return fmt.Errorf("parse: %d", rtMsg.Errno)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
|
||||
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
|
||||
if addr.Is4() {
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package net
|
||||
|
||||
func (d *Dialer) init() {
|
||||
d.Dialer.Control = applyBoundIfToSocket
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
//go:build !linux && !windows
|
||||
|
||||
package net
|
||||
|
||||
|
||||
24
client/net/env_android.go
Normal file
24
client/net/env_android.go
Normal file
@@ -0,0 +1,24 @@
|
||||
//go:build android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment for Android
|
||||
func Init() {
|
||||
// No initialization needed on Android
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||
// Always returns true on Android since we cannot handle routes dynamically.
|
||||
func AdvancedRouting() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on Android
|
||||
func SetVPNInterfaceName(name string) {
|
||||
// No-op on Android - not needed for Android VPN service
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns empty string on Android
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
121
client/net/env_bound_iface_test.go
Normal file
121
client/net/env_bound_iface_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build (darwin && !ios) || windows
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// resetAdvancedRoutingState resets the package-level advancedRoutingSupported variable
|
||||
// and cleans up after the test.
|
||||
func resetAdvancedRoutingState(t *testing.T) {
|
||||
t.Helper()
|
||||
orig := advancedRoutingSupported
|
||||
t.Cleanup(func() { advancedRoutingSupported = orig })
|
||||
}
|
||||
|
||||
// TestCheckAdvancedRoutingSupport_LegacyRoutingTrue verifies that setting
|
||||
// NB_USE_LEGACY_ROUTING=true disables advanced routing.
|
||||
func TestCheckAdvancedRoutingSupport_LegacyRoutingTrue(t *testing.T) {
|
||||
t.Setenv(envUseLegacyRouting, "true")
|
||||
assert.False(t, checkAdvancedRoutingSupport())
|
||||
}
|
||||
|
||||
// TestCheckAdvancedRoutingSupport_LegacyRoutingFalse verifies that
|
||||
// NB_USE_LEGACY_ROUTING=false still allows advanced routing when netstack is off.
|
||||
func TestCheckAdvancedRoutingSupport_LegacyRoutingFalse(t *testing.T) {
|
||||
t.Setenv(envUseLegacyRouting, "false")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
assert.True(t, checkAdvancedRoutingSupport())
|
||||
}
|
||||
|
||||
// TestCheckAdvancedRoutingSupport_LegacyRoutingInvalid verifies that an invalid
|
||||
// value for NB_USE_LEGACY_ROUTING is ignored (treated as false), so advanced
|
||||
// routing remains enabled when netstack is off.
|
||||
func TestCheckAdvancedRoutingSupport_LegacyRoutingInvalid(t *testing.T) {
|
||||
t.Setenv(envUseLegacyRouting, "notabool")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
// The invalid value is ignored; the default (false) is kept, so advanced routing
|
||||
// is not suppressed by the legacy-routing flag.
|
||||
assert.True(t, checkAdvancedRoutingSupport())
|
||||
}
|
||||
|
||||
// TestCheckAdvancedRoutingSupport_NetstackEnabled verifies that netstack mode
|
||||
// disables advanced routing.
|
||||
func TestCheckAdvancedRoutingSupport_NetstackEnabled(t *testing.T) {
|
||||
t.Setenv(envUseLegacyRouting, "false")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "true")
|
||||
assert.False(t, checkAdvancedRoutingSupport())
|
||||
}
|
||||
|
||||
// TestCheckAdvancedRoutingSupport_NoEnvVars verifies that with no env overrides
|
||||
// advanced routing is supported (the happy path).
|
||||
func TestCheckAdvancedRoutingSupport_NoEnvVars(t *testing.T) {
|
||||
// Unset both controlling variables so we hit the default path.
|
||||
t.Setenv(envUseLegacyRouting, "")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
assert.True(t, checkAdvancedRoutingSupport())
|
||||
}
|
||||
|
||||
// TestCheckAdvancedRoutingSupport_LegacyRoutingEmptyString verifies that an
|
||||
// empty NB_USE_LEGACY_ROUTING is treated as "not set" and does not disable
|
||||
// advanced routing.
|
||||
func TestCheckAdvancedRoutingSupport_LegacyRoutingEmptyString(t *testing.T) {
|
||||
t.Setenv(envUseLegacyRouting, "")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
assert.True(t, checkAdvancedRoutingSupport())
|
||||
}
|
||||
|
||||
// TestAdvancedRouting_ReflectsInit verifies that after calling Init() with
|
||||
// NB_USE_LEGACY_ROUTING=true, AdvancedRouting() returns false.
|
||||
func TestAdvancedRouting_ReflectsInit(t *testing.T) {
|
||||
resetAdvancedRoutingState(t)
|
||||
|
||||
t.Setenv(envUseLegacyRouting, "true")
|
||||
Init()
|
||||
|
||||
assert.False(t, AdvancedRouting(), "AdvancedRouting should return false after Init with legacy routing")
|
||||
}
|
||||
|
||||
// TestAdvancedRouting_ReflectsInit_Advanced verifies that after calling Init()
|
||||
// without legacy overrides, AdvancedRouting() returns true.
|
||||
func TestAdvancedRouting_ReflectsInit_Advanced(t *testing.T) {
|
||||
resetAdvancedRoutingState(t)
|
||||
|
||||
t.Setenv(envUseLegacyRouting, "false")
|
||||
t.Setenv("NB_USE_NETSTACK_MODE", "false")
|
||||
Init()
|
||||
|
||||
assert.True(t, AdvancedRouting(), "AdvancedRouting should return true after Init without legacy overrides")
|
||||
}
|
||||
|
||||
// TestSetAndGetVPNInterfaceName verifies SetVPNInterfaceName and GetVPNInterfaceName
|
||||
// are consistent.
|
||||
func TestSetAndGetVPNInterfaceName(t *testing.T) {
|
||||
orig := GetVPNInterfaceName()
|
||||
t.Cleanup(func() { SetVPNInterfaceName(orig) })
|
||||
|
||||
SetVPNInterfaceName("utun3")
|
||||
assert.Equal(t, "utun3", GetVPNInterfaceName())
|
||||
}
|
||||
|
||||
// TestSetVPNInterfaceName_Empty verifies that setting an empty name is accepted.
|
||||
func TestSetVPNInterfaceName_Empty(t *testing.T) {
|
||||
orig := GetVPNInterfaceName()
|
||||
t.Cleanup(func() { SetVPNInterfaceName(orig) })
|
||||
|
||||
SetVPNInterfaceName("")
|
||||
assert.Equal(t, "", GetVPNInterfaceName())
|
||||
}
|
||||
|
||||
// TestSetVPNInterfaceName_OverwritesPrevious verifies that the second call wins.
|
||||
func TestSetVPNInterfaceName_OverwritesPrevious(t *testing.T) {
|
||||
orig := GetVPNInterfaceName()
|
||||
t.Cleanup(func() { SetVPNInterfaceName(orig) })
|
||||
|
||||
SetVPNInterfaceName("utun1")
|
||||
SetVPNInterfaceName("utun9")
|
||||
assert.Equal(t, "utun9", GetVPNInterfaceName())
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
//go:build !linux && !windows && !android
|
||||
|
||||
package net
|
||||
|
||||
|
||||
@@ -1,25 +0,0 @@
|
||||
//go:build ios || android
|
||||
|
||||
package net
|
||||
|
||||
// Init initializes the network environment for mobile platforms.
|
||||
func Init() {
|
||||
// no-op on mobile: routing scope is owned by the VPN extension.
|
||||
}
|
||||
|
||||
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
|
||||
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
|
||||
// owns the routing scope.
|
||||
func AdvancedRouting() bool {
|
||||
return true
|
||||
}
|
||||
|
||||
// SetVPNInterfaceName is a no-op on mobile.
|
||||
func SetVPNInterfaceName(string) {
|
||||
// no-op on mobile: the VPN extension manages the interface.
|
||||
}
|
||||
|
||||
// GetVPNInterfaceName returns an empty string on mobile.
|
||||
func GetVPNInterfaceName() string {
|
||||
return ""
|
||||
}
|
||||
39
client/net/env_mobile_test.go
Normal file
39
client/net/env_mobile_test.go
Normal file
@@ -0,0 +1,39 @@
|
||||
//go:build ios || android
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestAdvancedRouting_Mobile verifies that AdvancedRouting always returns true
|
||||
// on mobile platforms.
|
||||
func TestAdvancedRouting_Mobile(t *testing.T) {
|
||||
assert.True(t, AdvancedRouting(), "AdvancedRouting must always be true on mobile")
|
||||
}
|
||||
|
||||
// TestInit_Mobile verifies that Init is a no-op and does not panic.
|
||||
func TestInit_Mobile(t *testing.T) {
|
||||
// Should not panic.
|
||||
Init()
|
||||
// After Init, AdvancedRouting must still return true.
|
||||
assert.True(t, AdvancedRouting())
|
||||
}
|
||||
|
||||
// TestSetVPNInterfaceName_Mobile verifies that SetVPNInterfaceName is a no-op
|
||||
// and does not panic.
|
||||
func TestSetVPNInterfaceName_Mobile(t *testing.T) {
|
||||
// Should not panic for any input.
|
||||
SetVPNInterfaceName("utun0")
|
||||
SetVPNInterfaceName("")
|
||||
}
|
||||
|
||||
// TestGetVPNInterfaceName_Mobile verifies that GetVPNInterfaceName always
|
||||
// returns an empty string on mobile.
|
||||
func TestGetVPNInterfaceName_Mobile(t *testing.T) {
|
||||
// Even after a SetVPNInterfaceName call (no-op), the getter returns "".
|
||||
SetVPNInterfaceName("utun0")
|
||||
assert.Equal(t, "", GetVPNInterfaceName(), "GetVPNInterfaceName must return empty string on mobile")
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build (darwin && !ios) || windows
|
||||
//go:build windows
|
||||
|
||||
package net
|
||||
|
||||
@@ -24,22 +24,17 @@ func Init() {
|
||||
}
|
||||
|
||||
func checkAdvancedRoutingSupport() bool {
|
||||
legacyRouting := false
|
||||
var err error
|
||||
var legacyRouting bool
|
||||
if val := os.Getenv(envUseLegacyRouting); val != "" {
|
||||
parsed, err := strconv.ParseBool(val)
|
||||
legacyRouting, err = strconv.ParseBool(val)
|
||||
if err != nil {
|
||||
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
|
||||
} else {
|
||||
legacyRouting = parsed
|
||||
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
|
||||
}
|
||||
}
|
||||
|
||||
if legacyRouting {
|
||||
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
|
||||
return false
|
||||
}
|
||||
if netstack.IsEnabled() {
|
||||
log.Info("advanced routing disabled: netstack mode is enabled")
|
||||
if legacyRouting || netstack.IsEnabled() {
|
||||
log.Info("advanced routing has been requested to be disabled")
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -1,5 +0,0 @@
|
||||
package net
|
||||
|
||||
func (l *ListenerConfig) init() {
|
||||
l.ListenConfig.Control = applyBoundIfToSocket
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
//go:build !linux && !windows && !darwin
|
||||
//go:build !linux && !windows
|
||||
|
||||
package net
|
||||
|
||||
|
||||
@@ -1,160 +0,0 @@
|
||||
package net
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
|
||||
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
|
||||
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
|
||||
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
|
||||
// dispatched by socket domain (AF_INET only) not by inp_vflag.
|
||||
|
||||
// boundIface holds the physical interface chosen at routing setup time. Sockets
|
||||
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
|
||||
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
|
||||
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
|
||||
// following the VPN's split default.
|
||||
var (
|
||||
boundIfaceMu sync.RWMutex
|
||||
boundIface4 *net.Interface
|
||||
boundIface6 *net.Interface
|
||||
)
|
||||
|
||||
// SetBoundInterface records the egress interface for an address family. Called
|
||||
// by the routemanager after a scoped default route has been installed.
|
||||
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
|
||||
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
|
||||
func SetBoundInterface(af int, iface *net.Interface) {
|
||||
if iface == nil {
|
||||
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
|
||||
return
|
||||
}
|
||||
boundIfaceMu.Lock()
|
||||
defer boundIfaceMu.Unlock()
|
||||
switch af {
|
||||
case unix.AF_INET:
|
||||
boundIface4 = iface
|
||||
case unix.AF_INET6:
|
||||
boundIface6 = iface
|
||||
default:
|
||||
log.Warnf("SetBoundInterface: unsupported address family %d", af)
|
||||
}
|
||||
}
|
||||
|
||||
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
|
||||
// routemanager during cleanup.
|
||||
func ClearBoundInterfaces() {
|
||||
boundIfaceMu.Lock()
|
||||
defer boundIfaceMu.Unlock()
|
||||
boundIface4 = nil
|
||||
boundIface6 = nil
|
||||
}
|
||||
|
||||
// boundInterfaceFor returns the cached egress interface for a socket's address
|
||||
// family, falling back to the other family if the preferred slot is empty.
|
||||
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
|
||||
// either setsockopt scopes the socket; preferring same-family still matters
|
||||
// when v4 and v6 defaults egress different NICs.
|
||||
func boundInterfaceFor(network, address string) *net.Interface {
|
||||
if iface := zoneInterface(address); iface != nil {
|
||||
return iface
|
||||
}
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
defer boundIfaceMu.RUnlock()
|
||||
|
||||
primary, secondary := boundIface4, boundIface6
|
||||
if isV6Network(network) {
|
||||
primary, secondary = boundIface6, boundIface4
|
||||
}
|
||||
if primary != nil {
|
||||
return primary
|
||||
}
|
||||
return secondary
|
||||
}
|
||||
|
||||
func isV6Network(network string) bool {
|
||||
return strings.HasSuffix(network, "6")
|
||||
}
|
||||
|
||||
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
|
||||
func zoneInterface(address string) *net.Interface {
|
||||
if address == "" {
|
||||
return nil
|
||||
}
|
||||
addr, err := netip.ParseAddrPort(address)
|
||||
if err != nil {
|
||||
a, err := netip.ParseAddr(address)
|
||||
if err != nil {
|
||||
return nil
|
||||
}
|
||||
addr = netip.AddrPortFrom(a, 0)
|
||||
}
|
||||
zone := addr.Addr().Zone()
|
||||
if zone == "" {
|
||||
return nil
|
||||
}
|
||||
if iface, err := net.InterfaceByName(zone); err == nil {
|
||||
return iface
|
||||
}
|
||||
if idx, err := strconv.Atoi(zone); err == nil {
|
||||
if iface, err := net.InterfaceByIndex(idx); err == nil {
|
||||
return iface
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
|
||||
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
|
||||
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// applyBoundIfToSocket binds the socket to the cached physical egress interface
|
||||
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
|
||||
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
|
||||
if !AdvancedRouting() {
|
||||
return nil
|
||||
}
|
||||
|
||||
iface := boundInterfaceFor(network, address)
|
||||
if iface == nil {
|
||||
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
|
||||
return nil
|
||||
}
|
||||
|
||||
isV6 := isV6Network(network)
|
||||
var controlErr error
|
||||
if err := c.Control(func(fd uintptr) {
|
||||
if isV6 {
|
||||
controlErr = setIPv6BoundIf(fd, iface)
|
||||
} else {
|
||||
controlErr = setIPv4BoundIf(fd, iface)
|
||||
}
|
||||
if controlErr == nil {
|
||||
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
|
||||
}
|
||||
}); err != nil {
|
||||
return fmt.Errorf("control: %w", err)
|
||||
}
|
||||
return controlErr
|
||||
}
|
||||
286
client/net/net_darwin_test.go
Normal file
286
client/net/net_darwin_test.go
Normal file
@@ -0,0 +1,286 @@
|
||||
//go:build darwin && !ios
|
||||
|
||||
package net
|
||||
|
||||
import (
|
||||
"net"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"golang.org/x/sys/unix"
|
||||
)
|
||||
|
||||
// resetBoundIfaces clears global state before each test that manipulates it.
|
||||
func resetBoundIfaces(t *testing.T) {
|
||||
t.Helper()
|
||||
ClearBoundInterfaces()
|
||||
t.Cleanup(ClearBoundInterfaces)
|
||||
}
|
||||
|
||||
// TestIsV6Network verifies the isV6Network helper correctly identifies v6 networks.
|
||||
func TestIsV6Network(t *testing.T) {
|
||||
tests := []struct {
|
||||
network string
|
||||
want bool
|
||||
}{
|
||||
{"tcp6", true},
|
||||
{"udp6", true},
|
||||
{"ip6", true},
|
||||
{"tcp", false},
|
||||
{"udp", false},
|
||||
{"tcp4", false},
|
||||
{"udp4", false},
|
||||
{"ip4", false},
|
||||
{"ip", false},
|
||||
{"", false},
|
||||
// Arbitrary suffix-6 strings
|
||||
{"unix6", true},
|
||||
{"custom6", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.network, func(t *testing.T) {
|
||||
assert.Equal(t, tt.want, isV6Network(tt.network))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestSetBoundInterface_AFINET sets boundIface4 via AF_INET.
|
||||
func TestSetBoundInterface_AFINET(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
iface := &net.Interface{Index: 5, Name: "en0"}
|
||||
SetBoundInterface(unix.AF_INET, iface)
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
got := boundIface4
|
||||
boundIfaceMu.RUnlock()
|
||||
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, 5, got.Index)
|
||||
assert.Equal(t, "en0", got.Name)
|
||||
}
|
||||
|
||||
// TestSetBoundInterface_AFINET6 sets boundIface6 via AF_INET6.
|
||||
func TestSetBoundInterface_AFINET6(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
iface := &net.Interface{Index: 7, Name: "utun0"}
|
||||
SetBoundInterface(unix.AF_INET6, iface)
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
got := boundIface6
|
||||
boundIfaceMu.RUnlock()
|
||||
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, 7, got.Index)
|
||||
assert.Equal(t, "utun0", got.Name)
|
||||
}
|
||||
|
||||
// TestSetBoundInterface_Nil verifies nil iface is rejected without panicking.
|
||||
func TestSetBoundInterface_Nil(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
// Should not panic, just log a warning and leave existing values untouched.
|
||||
iface := &net.Interface{Index: 1, Name: "en1"}
|
||||
SetBoundInterface(unix.AF_INET, iface)
|
||||
|
||||
SetBoundInterface(unix.AF_INET, nil)
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
got := boundIface4
|
||||
boundIfaceMu.RUnlock()
|
||||
|
||||
// Value must be unchanged after the rejected nil write.
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, 1, got.Index)
|
||||
}
|
||||
|
||||
// TestSetBoundInterface_UnknownAF verifies unknown address families are ignored.
|
||||
func TestSetBoundInterface_UnknownAF(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
iface := &net.Interface{Index: 3, Name: "en2"}
|
||||
// Use an address family that is not AF_INET or AF_INET6.
|
||||
SetBoundInterface(99, iface)
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
v4, v6 := boundIface4, boundIface6
|
||||
boundIfaceMu.RUnlock()
|
||||
|
||||
assert.Nil(t, v4, "unknown AF must not populate boundIface4")
|
||||
assert.Nil(t, v6, "unknown AF must not populate boundIface6")
|
||||
}
|
||||
|
||||
// TestClearBoundInterfaces clears both cached interfaces.
|
||||
func TestClearBoundInterfaces(t *testing.T) {
|
||||
iface4 := &net.Interface{Index: 1, Name: "en0"}
|
||||
iface6 := &net.Interface{Index: 2, Name: "en0"}
|
||||
|
||||
SetBoundInterface(unix.AF_INET, iface4)
|
||||
SetBoundInterface(unix.AF_INET6, iface6)
|
||||
|
||||
ClearBoundInterfaces()
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
v4, v6 := boundIface4, boundIface6
|
||||
boundIfaceMu.RUnlock()
|
||||
|
||||
assert.Nil(t, v4, "boundIface4 must be nil after clear")
|
||||
assert.Nil(t, v6, "boundIface6 must be nil after clear")
|
||||
}
|
||||
|
||||
// TestClearBoundInterfaces_Idempotent verifies clearing twice does not panic.
|
||||
func TestClearBoundInterfaces_Idempotent(t *testing.T) {
|
||||
ClearBoundInterfaces()
|
||||
ClearBoundInterfaces()
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
v4, v6 := boundIface4, boundIface6
|
||||
boundIfaceMu.RUnlock()
|
||||
|
||||
assert.Nil(t, v4)
|
||||
assert.Nil(t, v6)
|
||||
}
|
||||
|
||||
// TestBoundInterfaceFor_PreferSameFamily verifies v4 iface returned for "tcp" and
|
||||
// v6 iface returned for "tcp6" when both slots are populated.
|
||||
func TestBoundInterfaceFor_PreferSameFamily(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
en0 := &net.Interface{Index: 1, Name: "en0"}
|
||||
en1 := &net.Interface{Index: 2, Name: "en1"}
|
||||
SetBoundInterface(unix.AF_INET, en0)
|
||||
SetBoundInterface(unix.AF_INET6, en1)
|
||||
|
||||
got4 := boundInterfaceFor("tcp", "1.2.3.4:80")
|
||||
require.NotNil(t, got4)
|
||||
assert.Equal(t, "en0", got4.Name, "tcp should prefer v4 interface")
|
||||
|
||||
got6 := boundInterfaceFor("tcp6", "[::1]:80")
|
||||
require.NotNil(t, got6)
|
||||
assert.Equal(t, "en1", got6.Name, "tcp6 should prefer v6 interface")
|
||||
}
|
||||
|
||||
// TestBoundInterfaceFor_FallbackToOtherFamily returns the other family's iface
|
||||
// when the preferred slot is empty.
|
||||
func TestBoundInterfaceFor_FallbackToOtherFamily(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
// Only v4 populated.
|
||||
en0 := &net.Interface{Index: 1, Name: "en0"}
|
||||
SetBoundInterface(unix.AF_INET, en0)
|
||||
|
||||
// Asking for v6 should fall back to en0.
|
||||
got := boundInterfaceFor("tcp6", "[::1]:80")
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, "en0", got.Name)
|
||||
}
|
||||
|
||||
// TestBoundInterfaceFor_BothEmpty returns nil when both slots are empty.
|
||||
func TestBoundInterfaceFor_BothEmpty(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
got := boundInterfaceFor("tcp", "1.2.3.4:80")
|
||||
assert.Nil(t, got)
|
||||
|
||||
got6 := boundInterfaceFor("tcp6", "[::1]:80")
|
||||
assert.Nil(t, got6)
|
||||
}
|
||||
|
||||
// TestZoneInterface_Empty returns nil for empty address.
|
||||
func TestZoneInterface_Empty(t *testing.T) {
|
||||
iface := zoneInterface("")
|
||||
assert.Nil(t, iface)
|
||||
}
|
||||
|
||||
// TestZoneInterface_NoZone returns nil when address has no zone.
|
||||
func TestZoneInterface_NoZone(t *testing.T) {
|
||||
// Regular IPv4 address with port — no zone identifier.
|
||||
iface := zoneInterface("192.168.1.1:80")
|
||||
assert.Nil(t, iface)
|
||||
|
||||
// Regular IPv6 address with port — no zone identifier.
|
||||
iface = zoneInterface("[2001:db8::1]:80")
|
||||
assert.Nil(t, iface)
|
||||
|
||||
// Plain IPv6 address without port or zone.
|
||||
iface = zoneInterface("2001:db8::1")
|
||||
assert.Nil(t, iface)
|
||||
}
|
||||
|
||||
// TestZoneInterface_InvalidAddress returns nil for completely invalid strings.
|
||||
func TestZoneInterface_InvalidAddress(t *testing.T) {
|
||||
iface := zoneInterface("not-an-address")
|
||||
assert.Nil(t, iface)
|
||||
|
||||
iface = zoneInterface("::::")
|
||||
assert.Nil(t, iface)
|
||||
}
|
||||
|
||||
// TestZoneInterface_NonExistentZoneName returns nil for a zone name that does
|
||||
// not correspond to a real interface on the host.
|
||||
func TestZoneInterface_NonExistentZoneName(t *testing.T) {
|
||||
// Use an interface name that is very unlikely to exist.
|
||||
iface := zoneInterface("fe80::1%nonexistentiface99999")
|
||||
assert.Nil(t, iface)
|
||||
}
|
||||
|
||||
// TestZoneInterface_NonExistentZoneIndex returns nil for a zone expressed as an
|
||||
// integer index that is not in use.
|
||||
func TestZoneInterface_NonExistentZoneIndex(t *testing.T) {
|
||||
// Interface index 999999 should not exist on any test machine.
|
||||
iface := zoneInterface("fe80::1%999999")
|
||||
assert.Nil(t, iface)
|
||||
}
|
||||
|
||||
// TestBoundInterfaceFor_SetThenClear verifies that clearing state causes
|
||||
// boundInterfaceFor to return nil afterwards.
|
||||
func TestBoundInterfaceFor_SetThenClear(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
en0 := &net.Interface{Index: 1, Name: "en0"}
|
||||
SetBoundInterface(unix.AF_INET, en0)
|
||||
|
||||
got := boundInterfaceFor("tcp", "1.2.3.4:80")
|
||||
require.NotNil(t, got, "should return iface while set")
|
||||
|
||||
ClearBoundInterfaces()
|
||||
|
||||
got = boundInterfaceFor("tcp", "1.2.3.4:80")
|
||||
assert.Nil(t, got, "should return nil after clear")
|
||||
}
|
||||
|
||||
// TestSetBoundInterface_OverwritesPreviousValue verifies that calling
|
||||
// SetBoundInterface again updates the stored pointer.
|
||||
func TestSetBoundInterface_OverwritesPreviousValue(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
first := &net.Interface{Index: 1, Name: "en0"}
|
||||
second := &net.Interface{Index: 3, Name: "en1"}
|
||||
|
||||
SetBoundInterface(unix.AF_INET, first)
|
||||
SetBoundInterface(unix.AF_INET, second)
|
||||
|
||||
boundIfaceMu.RLock()
|
||||
got := boundIface4
|
||||
boundIfaceMu.RUnlock()
|
||||
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, "en1", got.Name, "second call should overwrite first")
|
||||
}
|
||||
|
||||
// TestBoundInterfaceFor_OnlyV6Populated returns v6 iface for v4 network
|
||||
// when only v6 slot is filled.
|
||||
func TestBoundInterfaceFor_OnlyV6Populated(t *testing.T) {
|
||||
resetBoundIfaces(t)
|
||||
|
||||
en1 := &net.Interface{Index: 2, Name: "en1"}
|
||||
SetBoundInterface(unix.AF_INET6, en1)
|
||||
|
||||
// v4 network, v4 slot empty → should fall back to v6 slot.
|
||||
got := boundInterfaceFor("tcp", "1.2.3.4:80")
|
||||
require.NotNil(t, got)
|
||||
assert.Equal(t, "en1", got.Name)
|
||||
}
|
||||
@@ -12,6 +12,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
@@ -137,8 +138,10 @@ func restoreResidualState(ctx context.Context, statePath string) error {
|
||||
}
|
||||
|
||||
// clean up any remaining routes independently of the state file
|
||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||
if !nbnet.AdvancedRouting() {
|
||||
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
|
||||
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
|
||||
}
|
||||
}
|
||||
|
||||
return nberrors.FormatErrorOrNil(merr)
|
||||
|
||||
@@ -187,23 +187,24 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
|
||||
return "", fmt.Errorf("get NetBird executable path: %w", err)
|
||||
}
|
||||
|
||||
hostList := strings.Join(deduplicatedPatterns, ",")
|
||||
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
hostLine := strings.Join(deduplicatedPatterns, " ")
|
||||
config := fmt.Sprintf("Host %s\n", hostLine)
|
||||
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
|
||||
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
|
||||
config += " PasswordAuthentication yes\n"
|
||||
config += " PubkeyAuthentication yes\n"
|
||||
config += " BatchMode no\n"
|
||||
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
|
||||
config += " StrictHostKeyChecking no\n"
|
||||
|
||||
if runtime.GOOS == "windows" {
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
config += " UserKnownHostsFile NUL\n"
|
||||
} else {
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
config += " UserKnownHostsFile /dev/null\n"
|
||||
}
|
||||
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
config += " CheckHostIP no\n"
|
||||
config += " LogLevel ERROR\n\n"
|
||||
|
||||
return config, nil
|
||||
}
|
||||
|
||||
@@ -116,37 +116,6 @@ func TestManager_PeerLimit(t *testing.T) {
|
||||
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
|
||||
}
|
||||
|
||||
func TestManager_MatchHostFormat(t *testing.T) {
|
||||
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
|
||||
require.NoError(t, err)
|
||||
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
|
||||
|
||||
manager := &Manager{
|
||||
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
|
||||
sshConfigFile: "99-netbird.conf",
|
||||
}
|
||||
|
||||
peers := []PeerSSHInfo{
|
||||
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
|
||||
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
|
||||
}
|
||||
|
||||
err = manager.SetupSSHClientConfig(peers)
|
||||
require.NoError(t, err)
|
||||
|
||||
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
|
||||
content, err := os.ReadFile(configPath)
|
||||
require.NoError(t, err)
|
||||
configStr := string(content)
|
||||
|
||||
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
|
||||
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
|
||||
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
|
||||
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
|
||||
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
|
||||
"should use Match host with comma-separated patterns")
|
||||
}
|
||||
|
||||
func TestManager_ForcedSSHConfig(t *testing.T) {
|
||||
// Set force environment variable
|
||||
t.Setenv(EnvForceSSHConfig, "true")
|
||||
|
||||
@@ -2,6 +2,7 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"strings"
|
||||
|
||||
@@ -144,6 +145,59 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
|
||||
return v
|
||||
}
|
||||
|
||||
func networkAddresses() ([]NetworkAddress, error) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if iface.HardwareAddr.String() == "" {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, address := range addrs {
|
||||
ipNet, ok := address.(*net.IPNet)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
|
||||
if ipNet.IP.IsLoopback() {
|
||||
continue
|
||||
}
|
||||
|
||||
netAddr := NetworkAddress{
|
||||
NetIP: netip.MustParsePrefix(ipNet.String()),
|
||||
Mac: iface.HardwareAddr.String(),
|
||||
}
|
||||
|
||||
if isDuplicated(netAddresses, netAddr) {
|
||||
continue
|
||||
}
|
||||
|
||||
netAddresses = append(netAddresses, netAddr)
|
||||
}
|
||||
}
|
||||
return netAddresses, nil
|
||||
}
|
||||
|
||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
for _, duplicated := range addresses {
|
||||
if duplicated.NetIP == addr.NetIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// GetInfoWithChecks retrieves and parses the system information with applied checks.
|
||||
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
|
||||
log.Debugf("gathering system information with checks: %d", len(checks))
|
||||
|
||||
@@ -2,8 +2,6 @@ package system
|
||||
|
||||
import (
|
||||
"context"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -44,66 +42,6 @@ func GetInfo(ctx context.Context) *Info {
|
||||
return gio
|
||||
}
|
||||
|
||||
// networkAddresses returns the list of network addresses on iOS.
|
||||
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
|
||||
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
|
||||
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
|
||||
// check that other platforms use and filter out link-local addresses as they
|
||||
// are not useful for posture checks.
|
||||
func networkAddresses() ([]NetworkAddress, error) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
for _, address := range addrs {
|
||||
netAddr, ok := toNetworkAddress(address)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDuplicated(netAddresses, netAddr) {
|
||||
continue
|
||||
}
|
||||
netAddresses = append(netAddresses, netAddr)
|
||||
}
|
||||
}
|
||||
return netAddresses, nil
|
||||
}
|
||||
|
||||
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
|
||||
ipNet, ok := address.(*net.IPNet)
|
||||
if !ok {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||
if err != nil {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
return NetworkAddress{NetIP: prefix, Mac: ""}, true
|
||||
}
|
||||
|
||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
for _, duplicated := range addresses {
|
||||
if duplicated.NetIP == addr.NetIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
|
||||
func checkFileAndProcess(paths []string) ([]File, error) {
|
||||
return []File{}, nil
|
||||
|
||||
@@ -1,66 +0,0 @@
|
||||
//go:build !ios
|
||||
|
||||
package system
|
||||
|
||||
import (
|
||||
"net"
|
||||
"net/netip"
|
||||
)
|
||||
|
||||
func networkAddresses() ([]NetworkAddress, error) {
|
||||
interfaces, err := net.Interfaces()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
var netAddresses []NetworkAddress
|
||||
for _, iface := range interfaces {
|
||||
if iface.Flags&net.FlagUp == 0 {
|
||||
continue
|
||||
}
|
||||
if iface.HardwareAddr.String() == "" {
|
||||
continue
|
||||
}
|
||||
addrs, err := iface.Addrs()
|
||||
if err != nil {
|
||||
continue
|
||||
}
|
||||
|
||||
mac := iface.HardwareAddr.String()
|
||||
for _, address := range addrs {
|
||||
netAddr, ok := toNetworkAddress(address, mac)
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if isDuplicated(netAddresses, netAddr) {
|
||||
continue
|
||||
}
|
||||
netAddresses = append(netAddresses, netAddr)
|
||||
}
|
||||
}
|
||||
return netAddresses, nil
|
||||
}
|
||||
|
||||
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
|
||||
ipNet, ok := address.(*net.IPNet)
|
||||
if !ok {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
if ipNet.IP.IsLoopback() {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
prefix, err := netip.ParsePrefix(ipNet.String())
|
||||
if err != nil {
|
||||
return NetworkAddress{}, false
|
||||
}
|
||||
return NetworkAddress{NetIP: prefix, Mac: mac}, true
|
||||
}
|
||||
|
||||
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
|
||||
for _, duplicated := range addresses {
|
||||
if duplicated.NetIP == addr.NetIP {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
Reference in New Issue
Block a user