Compare commits

..

1 Commits

Author SHA1 Message Date
coderabbitai[bot]
cb1eaf9e0d CodeRabbit Generated Unit Tests: Add unit tests for PR changes 2026-04-20 07:51:24 +00:00
40 changed files with 921 additions and 1497 deletions

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)

View File

@@ -8,7 +8,6 @@ import (
"os"
"slices"
"sync"
"time"
"golang.org/x/exp/maps"
@@ -16,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -28,7 +26,6 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -71,30 +68,7 @@ type Client struct {
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
}
// NewClient instantiate a new Client
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
}
func (c *Client) RenewTun(fd int) error {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := cc.Engine()
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd)
}
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
}
func (c *Client) Networks() *NetworkArray {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := cc.Engine()
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.getConnectClient()
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -7,5 +7,4 @@ package android
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
CacheDir() string
}

View File

@@ -1,434 +0,0 @@
//go:build android
package android
import (
"context"
"errors"
"fmt"
"io"
"net"
"strconv"
"sync"
"time"
log "github.com/sirupsen/logrus"
gossh "golang.org/x/crypto/ssh"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/auth"
"github.com/netbirdio/netbird/client/internal/profilemanager"
nbssh "github.com/netbirdio/netbird/client/ssh"
"github.com/netbirdio/netbird/client/ssh/detection"
)
const (
sshDialTimeout = 30 * time.Second
sshDetectionTimeout = 5 * time.Second
)
// SSHTerminalListener receives SSH session events. It is implemented in Java.
//
// All callbacks are invoked from goroutines and may run concurrently with each
// other; the implementation must be safe to call from any thread.
type SSHTerminalListener interface {
OnConnected()
OnData(data []byte)
OnClose(reason string)
OnError(message string)
}
// SSHClient is a NetBird-aware SSH client exposed to Java via gomobile.
//
// It dials through the running NetBird tunnel and runs a standard SSH session
// on top with PTY enabled. Host-key verification uses the NetBird-provided
// peer SSH host keys, identical to the desktop client.
type SSHClient struct {
nb *Client
mu sync.Mutex
listener SSHTerminalListener
urlOpener URLOpener
sshClient *gossh.Client
session *gossh.Session
stdin io.WriteCloser
closed bool
}
// NewSSHClient creates a new SSH client bound to the running NetBird Client.
func NewSSHClient(c *Client) *SSHClient {
return &SSHClient{nb: c}
}
// SetListener registers the Java listener. Must be called before Connect to
// receive any events.
func (s *SSHClient) SetListener(l SSHTerminalListener) {
s.mu.Lock()
s.listener = l
s.mu.Unlock()
}
// SetURLOpener registers the Java URL opener used to display the device-code
// authorization page in a Custom Tabs window when the target peer requires
// JWT authentication. Must be set before Connect to be effective.
func (s *SSHClient) SetURLOpener(opener URLOpener) {
s.mu.Lock()
s.urlOpener = opener
s.mu.Unlock()
}
// Connect dials the SSH server through the NetBird tunnel and performs the
// SSH handshake. It auto-detects the server type via SSH banner inspection
// and selects the appropriate authentication path:
//
// - NetBird-SSH server requiring JWT: launches the OAuth 2.0 device-code
// flow, opens the verification URL through the registered URLOpener, and
// uses the resulting token as the SSH password. Host-key verification
// uses the NetBird peer registry.
// - NetBird-SSH server without JWT: authenticates with the NetBird SSH
// private key. Host-key verification uses the NetBird peer registry.
// - Regular SSH server (e.g. OpenSSH): authenticates with the NetBird key
// first (so a user-installed NetBird public key works), then falls back
// to the supplied password if non-empty. Host-key verification is
// disabled (TOFU pending).
//
// The password parameter is only consulted for regular SSH servers.
func (s *SSHClient) Connect(host string, port int, user, password string) error {
cfg, _, cc := s.nb.stateSnapshot()
if cc == nil {
return errors.New("netbird client not running")
}
if cfg == nil {
return errors.New("netbird config not loaded")
}
engine := cc.Engine()
if engine == nil {
return errors.New("netbird engine not available")
}
serverType := detectServerType(host, port)
log.Infof("SSH server type for %s:%d: %s", host, port, serverType)
authMethods, hostKeyCallback, err := s.buildAuth(cfg, engine, serverType, password)
if err != nil {
return err
}
clientConfig := &gossh.ClientConfig{
User: user,
Auth: authMethods,
HostKeyCallback: hostKeyCallback,
Timeout: sshDialTimeout,
}
return s.dialAndHandshake(host, port, clientConfig)
}
// StartSession requests a PTY and starts an interactive shell. Output from
// the session is forwarded to the listener via OnData.
func (s *SSHClient) StartSession(cols, rows int) error {
log.Debugf("SSH: starting session %dx%d", cols, rows)
s.mu.Lock()
sshClient := s.sshClient
s.mu.Unlock()
if sshClient == nil {
return errors.New("ssh client not connected")
}
session, err := sshClient.NewSession()
if err != nil {
return fmt.Errorf("new session: %w", err)
}
modes := gossh.TerminalModes{
gossh.ECHO: 1,
gossh.TTY_OP_ISPEED: 14400,
gossh.TTY_OP_OSPEED: 14400,
gossh.VINTR: 3,
gossh.VQUIT: 28,
gossh.VERASE: 127,
}
if err := session.RequestPty("xterm-256color", rows, cols, modes); err != nil {
closeQuiet(session, "session after pty error")
return fmt.Errorf("request pty: %w", err)
}
stdin, err := session.StdinPipe()
if err != nil {
closeQuiet(session, "session after stdin error")
return fmt.Errorf("stdin pipe: %w", err)
}
stdout, err := session.StdoutPipe()
if err != nil {
closeQuiet(session, "session after stdout error")
return fmt.Errorf("stdout pipe: %w", err)
}
stderr, err := session.StderrPipe()
if err != nil {
closeQuiet(session, "session after stderr error")
return fmt.Errorf("stderr pipe: %w", err)
}
if err := session.Shell(); err != nil {
closeQuiet(session, "session after shell error")
return fmt.Errorf("start shell: %w", err)
}
s.mu.Lock()
s.session = session
s.stdin = stdin
s.mu.Unlock()
go s.readLoop(stdout, "stdout")
go s.readLoop(stderr, "stderr")
log.Debug("SSH: session started, shell running")
return nil
}
// Write sends data to the SSH session stdin.
func (s *SSHClient) Write(data []byte) error {
s.mu.Lock()
stdin := s.stdin
s.mu.Unlock()
if stdin == nil {
return errors.New("ssh session not started")
}
if _, err := stdin.Write(data); err != nil {
return fmt.Errorf("write stdin: %w", err)
}
return nil
}
// Resize updates the PTY window size.
func (s *SSHClient) Resize(cols, rows int) error {
s.mu.Lock()
session := s.session
s.mu.Unlock()
if session == nil {
return errors.New("ssh session not started")
}
return session.WindowChange(rows, cols)
}
// Close terminates the SSH session and underlying connection. Safe to call
// multiple times.
func (s *SSHClient) Close() error {
s.mu.Lock()
sshClient := s.sshClient
session := s.session
stdin := s.stdin
s.sshClient = nil
s.session = nil
s.stdin = nil
s.mu.Unlock()
if stdin != nil {
if err := stdin.Close(); err != nil {
log.Debugf("ssh: stdin close: %v", err)
}
}
if session != nil {
if err := session.Close(); err != nil && !errors.Is(err, io.EOF) {
log.Debugf("ssh: session close: %v", err)
}
}
var firstErr error
if sshClient != nil {
if err := sshClient.Close(); err != nil {
firstErr = err
}
}
s.notifyClose("closed by client")
return firstErr
}
func (s *SSHClient) buildAuth(cfg *profilemanager.Config, engine *internal.Engine,
serverType detection.ServerType, password string) ([]gossh.AuthMethod, gossh.HostKeyCallback, error) {
switch serverType {
case detection.ServerTypeNetBirdJWT:
token, err := s.requestJWTToken(cfg)
if err != nil {
return nil, nil, fmt.Errorf("jwt: %w", err)
}
auths := []gossh.AuthMethod{gossh.Password(token)}
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
case detection.ServerTypeNetBirdNoJWT:
if cfg.SSHKey == "" {
return nil, nil, errors.New("no NetBird SSH key available")
}
signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey))
if err != nil {
return nil, nil, fmt.Errorf("parse netbird ssh key: %w", err)
}
auths := []gossh.AuthMethod{gossh.PublicKeys(signer)}
return auths, nbssh.CreateHostKeyCallback(&engineHostKeyVerifier{engine: engine}), nil
default: // regular SSH
var auths []gossh.AuthMethod
if cfg.SSHKey != "" {
if signer, err := gossh.ParsePrivateKey([]byte(cfg.SSHKey)); err == nil {
auths = append(auths, gossh.PublicKeys(signer))
} else {
log.Debugf("ssh: parse netbird key for regular auth: %v", err)
}
}
if password != "" {
pw := password
auths = append(auths, gossh.Password(pw))
auths = append(auths, gossh.KeyboardInteractive(func(_, _ string, questions []string, _ []bool) ([]string, error) {
answers := make([]string, len(questions))
for i := range questions {
answers[i] = pw
}
return answers, nil
}))
}
if len(auths) == 0 {
return nil, nil, errors.New("no auth method available: provide a password or configure NetBird SSH key")
}
return auths, gossh.InsecureIgnoreHostKey(), nil // nolint:gosec // TOFU not yet implemented
}
}
func (s *SSHClient) requestJWTToken(cfg *profilemanager.Config) (string, error) {
s.mu.Lock()
urlOpener := s.urlOpener
s.mu.Unlock()
if urlOpener == nil {
return "", errors.New("URL opener not configured for JWT auth")
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Minute)
defer cancel()
flow, err := auth.NewOAuthFlow(ctx, cfg, false, true, profilemanager.GetLoginHint())
if err != nil {
return "", fmt.Errorf("create oauth flow: %w", err)
}
flowInfo, err := flow.RequestAuthInfo(ctx)
if err != nil {
return "", fmt.Errorf("request auth info: %w", err)
}
go urlOpener.Open(flowInfo.VerificationURIComplete, flowInfo.UserCode)
tokenInfo, err := flow.WaitToken(ctx, flowInfo)
if err != nil {
return "", fmt.Errorf("wait for token: %w", err)
}
token := tokenInfo.GetTokenToUse()
if token == "" {
return "", errors.New("empty token returned by IdP")
}
return token, nil
}
func (s *SSHClient) dialAndHandshake(host string, port int, clientConfig *gossh.ClientConfig) error {
addr := net.JoinHostPort(host, strconv.Itoa(port))
log.Infof("SSH: connecting to %s as %s", addr, clientConfig.User)
ctx, cancel := context.WithTimeout(context.Background(), sshDialTimeout)
defer cancel()
var dialer net.Dialer
conn, err := dialer.DialContext(ctx, "tcp", addr)
if err != nil {
return fmt.Errorf("dial %s: %w", addr, err)
}
sshConn, chans, reqs, err := gossh.NewClientConn(conn, addr, clientConfig)
if err != nil {
if cerr := conn.Close(); cerr != nil {
log.Debugf("ssh: close after handshake error: %v", cerr)
}
return fmt.Errorf("ssh handshake: %w", err)
}
s.mu.Lock()
s.sshClient = gossh.NewClient(sshConn, chans, reqs)
listener := s.listener
s.mu.Unlock()
log.Infof("SSH: connected to %s", addr)
if listener != nil {
listener.OnConnected()
}
return nil
}
func (s *SSHClient) readLoop(r io.Reader, name string) {
buf := make([]byte, 4096)
for {
n, err := r.Read(buf)
if n > 0 {
s.mu.Lock()
listener := s.listener
s.mu.Unlock()
if listener != nil {
chunk := make([]byte, n)
copy(chunk, buf[:n])
listener.OnData(chunk)
}
}
if err != nil {
if !errors.Is(err, io.EOF) {
log.Debugf("ssh %s read: %v", name, err)
}
s.notifyClose(err.Error())
return
}
}
}
func (s *SSHClient) notifyClose(reason string) {
s.mu.Lock()
if s.closed {
s.mu.Unlock()
return
}
s.closed = true
listener := s.listener
s.mu.Unlock()
if listener != nil {
listener.OnClose(reason)
}
}
// engineHostKeyVerifier adapts *internal.Engine to nbssh.HostKeyVerifier.
type engineHostKeyVerifier struct {
engine *internal.Engine
}
func (v *engineHostKeyVerifier) VerifySSHHostKey(peerAddress string, presented []byte) error {
storedKey, found := v.engine.GetPeerSSHKey(peerAddress)
if !found {
return nbssh.ErrPeerNotFound
}
return nbssh.VerifyHostKey(storedKey, presented, peerAddress)
}
func closeQuiet(c io.Closer, label string) {
if c == nil {
return
}
if err := c.Close(); err != nil && !errors.Is(err, io.EOF) {
log.Debugf("ssh: close %s: %v", label, err)
}
}
func detectServerType(host string, port int) detection.ServerType {
ctx, cancel := context.WithTimeout(context.Background(), sshDetectionTimeout)
defer cancel()
dialer := &net.Dialer{}
serverType, err := detection.DetectSSHServerType(ctx, dialer, host, port)
if err != nil {
log.Debugf("ssh: server detection for %s:%d failed: %v (assuming regular SSH)", host, port, err)
return detection.ServerTypeRegular
}
return serverType
}

View File

@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}
// Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
if err := w.tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}

View File

@@ -1,113 +0,0 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
}
if u.address.Network.Contains(a) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}

View File

@@ -94,7 +94,6 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
cacheDir string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -104,7 +103,6 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
}
@@ -340,7 +338,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err)
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)

View File

@@ -16,6 +16,7 @@ import (
"path/filepath"
"runtime"
"runtime/pprof"
"slices"
"sort"
"strings"
"time"
@@ -30,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const readmeContent = `Netbird debug bundle
@@ -232,7 +234,6 @@ type BundleGenerator struct {
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
@@ -255,7 +256,6 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
@@ -275,7 +275,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
@@ -288,7 +287,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return "", fmt.Errorf("create zip file: %w", err)
}
@@ -374,8 +373,15 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if err := g.addPlatformLog(); err != nil {
log.Errorf("failed to add logs to debug bundle: %v", err)
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {

View File

@@ -1,41 +0,0 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -1,25 +0,0 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -140,7 +140,6 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config
LogPath string
TempDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -1096,7 +1095,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)

View File

@@ -22,8 +22,4 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager
FileDescriptor int32
StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
}

View File

@@ -1,10 +0,0 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }

View File

@@ -1,241 +0,0 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
if err := r.addScopedDefault(unspec, nexthop); err != nil {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -0,0 +1,125 @@
//go:build darwin && !ios
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/client/net"
)
// TestAfOf verifies that afOf returns the correct string for each address family.
func TestAfOf(t *testing.T) {
tests := []struct {
name string
addr netip.Addr
want string
}{
{
name: "IPv4 unspecified",
addr: netip.IPv4Unspecified(),
want: "IPv4",
},
{
name: "IPv4 private",
addr: netip.MustParseAddr("10.0.0.1"),
want: "IPv4",
},
{
name: "IPv4 loopback",
addr: netip.MustParseAddr("127.0.0.1"),
want: "IPv4",
},
{
name: "IPv6 unspecified",
addr: netip.IPv6Unspecified(),
want: "IPv6",
},
{
name: "IPv6 loopback",
addr: netip.MustParseAddr("::1"),
want: "IPv6",
},
{
name: "IPv6 unicast",
addr: netip.MustParseAddr("2001:db8::1"),
want: "IPv6",
},
{
name: "IPv6 link-local",
addr: netip.MustParseAddr("fe80::1"),
want: "IPv6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, afOf(tt.addr))
})
}
}
// TestIsAddrRouted_AdvancedRoutingBypassesTunnelLookup verifies that when
// AdvancedRouting is active, IsAddrRouted immediately returns (false, zero)
// regardless of the provided vpn routes, because the WG socket is bound to
// the physical interface via IP_BOUND_IF and bypasses the main routing table.
func TestIsAddrRouted_AdvancedRoutingBypassesTunnelLookup(t *testing.T) {
// On darwin, AdvancedRouting returns true unless overridden.
// Ensure we reset the state after the test.
t.Setenv("NB_USE_LEGACY_ROUTING", "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
nbnet.Init()
require.True(t, nbnet.AdvancedRouting(), "test requires advanced routing to be active on darwin")
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("0.0.0.0/0"),
}
tests := []struct {
name string
addr netip.Addr
}{
{"IPv4 in VPN route", netip.MustParseAddr("10.0.0.1")},
{"IPv4 in narrow VPN route", netip.MustParseAddr("192.168.1.100")},
{"IPv4 default route covered", netip.MustParseAddr("8.8.8.8")},
{"IPv6 in VPN route", netip.MustParseAddr("2001:db8::1")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routed, prefix := IsAddrRouted(tt.addr, vpnRoutes)
assert.False(t, routed, "should not be marked as routed via VPN when advanced routing is active")
assert.Equal(t, netip.Prefix{}, prefix, "matched prefix should be zero when advanced routing is active")
})
}
}
// TestIsAddrRouted_LegacyModeFallsThroughToTable verifies that when
// NB_USE_LEGACY_ROUTING=true disables advanced routing, IsAddrRouted
// performs the normal VPN-route vs local-route comparison.
func TestIsAddrRouted_LegacyModeFallsThroughToTable(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
nbnet.Init()
require.False(t, nbnet.AdvancedRouting(), "test requires advanced routing to be disabled")
// Use an address that is very unlikely to exist in the host routing table
// as a local route, so the VPN route wins.
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("198.51.100.0/24"), // TEST-NET-2 not in normal routing tables
}
addr := netip.MustParseAddr("198.51.100.1")
routed, _ := IsAddrRouted(addr, vpnRoutes)
// We cannot assert a specific outcome because it depends on the host's
// routing table, but we CAN assert that the call did not panic and returned
// a consistent pair.
_ = routed
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks"
)
@@ -32,6 +31,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{})
@@ -396,16 +397,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
if nbnet.AdvancedRouting() {
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
localRoutes, err := hasSeparateRouting()
if err != nil {
log.Errorf("Failed to get routes: %v", err)
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{}
}

View File

@@ -0,0 +1,140 @@
//go:build !android && !ios
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbnet "github.com/netbirdio/netbird/client/net"
)
// withLegacyRouting forcibly puts the nbnet package into legacy (non-advanced)
// routing mode for the duration of the test, then restores the previous value.
func withLegacyRouting(t *testing.T) {
t.Helper()
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
nbnet.Init()
t.Cleanup(func() {
// After the test, re-initialise with no overrides so that subsequent
// tests start with a clean state.
nbnet.Init()
})
}
// withAdvancedRouting attempts to enable advanced routing for the test.
// On platforms where Init() cannot produce AdvancedRouting()=true (e.g. Linux
// without root), the calling test is skipped.
func withAdvancedRouting(t *testing.T) {
t.Helper()
t.Setenv("NB_USE_LEGACY_ROUTING", "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
nbnet.Init()
t.Cleanup(func() {
nbnet.Init()
})
if !nbnet.AdvancedRouting() {
t.Skip("advanced routing not available in this environment (need root or darwin)")
}
}
// TestIsAddrRouted_AdvancedRoutingShortCircuit verifies that when
// AdvancedRouting() returns true, IsAddrRouted immediately returns
// (false, zero prefix) regardless of any VPN routes, because the WG socket is
// bound directly to the physical interface and bypasses the kernel routing table.
func TestIsAddrRouted_AdvancedRoutingShortCircuit(t *testing.T) {
withAdvancedRouting(t)
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("::/0"),
}
tests := []struct {
name string
addr string
}{
{"IPv4 matched by 10/8", "10.0.0.1"},
{"IPv4 matched by 192.168/16", "192.168.1.1"},
{"IPv4 caught by default", "8.8.8.8"},
{"IPv6 caught by default", "2001:db8::1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := netip.MustParseAddr(tt.addr)
routed, prefix := IsAddrRouted(addr, vpnRoutes)
assert.False(t, routed, "advanced routing must short-circuit VPN route check")
assert.Equal(t, netip.Prefix{}, prefix, "returned prefix must be zero under advanced routing")
})
}
}
// TestIsAddrRouted_AdvancedRouting_EmptyRoutes verifies the short-circuit with
// an empty vpnRoutes slice — the result must still be (false, zero).
func TestIsAddrRouted_AdvancedRouting_EmptyRoutes(t *testing.T) {
withAdvancedRouting(t)
routed, prefix := IsAddrRouted(netip.MustParseAddr("10.0.0.1"), nil)
assert.False(t, routed)
assert.Equal(t, netip.Prefix{}, prefix)
}
// TestIsAddrRouted_LegacyMode_NonVPNAddress verifies that under legacy routing
// an address that is not covered by any VPN prefix returns false.
func TestIsAddrRouted_LegacyMode_NonVPNAddress(t *testing.T) {
withLegacyRouting(t)
// 198.51.100.x (TEST-NET-2) is not normally in any routing table.
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
}
addr := netip.MustParseAddr("198.51.100.1")
routed, _ := IsAddrRouted(addr, vpnRoutes)
assert.False(t, routed, "address not in VPN routes should not be marked as VPN-routed")
}
// TestIsAddrRouted_LegacyMode_EmptyVPNRoutes verifies that an empty vpnRoutes
// slice always yields (false, zero prefix) even under legacy routing.
func TestIsAddrRouted_LegacyMode_EmptyVPNRoutes(t *testing.T) {
withLegacyRouting(t)
routed, prefix := IsAddrRouted(netip.MustParseAddr("10.0.0.1"), nil)
assert.False(t, routed)
assert.Equal(t, netip.Prefix{}, prefix)
}
// TestIsAddrRouted_AdvancedVsLegacy_ContrastiveBehaviour documents the
// contract difference between the two modes: with a VPN default route and an
// address that matches it, legacy mode may mark it as VPN-routed while advanced
// mode must never do so.
func TestIsAddrRouted_AdvancedVsLegacy_ContrastiveBehaviour(t *testing.T) {
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("0.0.0.0/0"),
}
addr := netip.MustParseAddr("8.8.8.8")
// --- advanced routing: must always return false ---
t.Run("advanced routing", func(t *testing.T) {
withAdvancedRouting(t)
routed, prefix := IsAddrRouted(addr, vpnRoutes)
assert.False(t, routed, "advanced routing must bypass VPN route lookup")
assert.Equal(t, netip.Prefix{}, prefix)
})
// --- legacy routing: delegates to kernel table check, does not panic ---
t.Run("legacy routing", func(t *testing.T) {
withLegacyRouting(t)
// We don't assert true/false here because it depends on the host
// routing table, but the call must not panic and must return
// a valid (bool, prefix) pair.
routed, prefix := IsAddrRouted(addr, vpnRoutes)
t.Logf("legacy IsAddrRouted(%s, %v) = (%v, %v)", addr, vpnRoutes, routed, prefix)
})
}

View File

@@ -22,6 +22,10 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil

View File

@@ -894,6 +894,13 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6
}
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,6 +48,10 @@ func EnableIPForwarding() error {
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,9 +25,6 @@ import (
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
)
var routeProtoFlag int
@@ -44,42 +41,26 @@ func init() {
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager)
}
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB()
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
@@ -136,12 +117,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix)
}
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return fmt.Errorf("build route message: %w", err)
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
@@ -151,91 +132,50 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil
}
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
// kernel's matching reply, retrying transient failures until budget elapses.
// Callers do not need to manage sockets or seq numbers themselves.
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
}
}()
tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
continue
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
return backoff.Permanent(fmt.Errorf("read: %w", err))
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
continue
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
// uniquely identifies our own reply among broadcast traffic.
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
continue
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
if hdr.Errno != 0 {
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -243,7 +183,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
}
@@ -282,6 +221,19 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {

View File

@@ -1,5 +0,0 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows
package net

24
client/net/env_android.go Normal file
View File

@@ -0,0 +1,24 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,121 @@
//go:build (darwin && !ios) || windows
package net
import (
"testing"
"github.com/stretchr/testify/assert"
)
// resetAdvancedRoutingState resets the package-level advancedRoutingSupported variable
// and cleans up after the test.
func resetAdvancedRoutingState(t *testing.T) {
t.Helper()
orig := advancedRoutingSupported
t.Cleanup(func() { advancedRoutingSupported = orig })
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingTrue verifies that setting
// NB_USE_LEGACY_ROUTING=true disables advanced routing.
func TestCheckAdvancedRoutingSupport_LegacyRoutingTrue(t *testing.T) {
t.Setenv(envUseLegacyRouting, "true")
assert.False(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingFalse verifies that
// NB_USE_LEGACY_ROUTING=false still allows advanced routing when netstack is off.
func TestCheckAdvancedRoutingSupport_LegacyRoutingFalse(t *testing.T) {
t.Setenv(envUseLegacyRouting, "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
assert.True(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingInvalid verifies that an invalid
// value for NB_USE_LEGACY_ROUTING is ignored (treated as false), so advanced
// routing remains enabled when netstack is off.
func TestCheckAdvancedRoutingSupport_LegacyRoutingInvalid(t *testing.T) {
t.Setenv(envUseLegacyRouting, "notabool")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
// The invalid value is ignored; the default (false) is kept, so advanced routing
// is not suppressed by the legacy-routing flag.
assert.True(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_NetstackEnabled verifies that netstack mode
// disables advanced routing.
func TestCheckAdvancedRoutingSupport_NetstackEnabled(t *testing.T) {
t.Setenv(envUseLegacyRouting, "false")
t.Setenv("NB_USE_NETSTACK_MODE", "true")
assert.False(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_NoEnvVars verifies that with no env overrides
// advanced routing is supported (the happy path).
func TestCheckAdvancedRoutingSupport_NoEnvVars(t *testing.T) {
// Unset both controlling variables so we hit the default path.
t.Setenv(envUseLegacyRouting, "")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
assert.True(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingEmptyString verifies that an
// empty NB_USE_LEGACY_ROUTING is treated as "not set" and does not disable
// advanced routing.
func TestCheckAdvancedRoutingSupport_LegacyRoutingEmptyString(t *testing.T) {
t.Setenv(envUseLegacyRouting, "")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
assert.True(t, checkAdvancedRoutingSupport())
}
// TestAdvancedRouting_ReflectsInit verifies that after calling Init() with
// NB_USE_LEGACY_ROUTING=true, AdvancedRouting() returns false.
func TestAdvancedRouting_ReflectsInit(t *testing.T) {
resetAdvancedRoutingState(t)
t.Setenv(envUseLegacyRouting, "true")
Init()
assert.False(t, AdvancedRouting(), "AdvancedRouting should return false after Init with legacy routing")
}
// TestAdvancedRouting_ReflectsInit_Advanced verifies that after calling Init()
// without legacy overrides, AdvancedRouting() returns true.
func TestAdvancedRouting_ReflectsInit_Advanced(t *testing.T) {
resetAdvancedRoutingState(t)
t.Setenv(envUseLegacyRouting, "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
Init()
assert.True(t, AdvancedRouting(), "AdvancedRouting should return true after Init without legacy overrides")
}
// TestSetAndGetVPNInterfaceName verifies SetVPNInterfaceName and GetVPNInterfaceName
// are consistent.
func TestSetAndGetVPNInterfaceName(t *testing.T) {
orig := GetVPNInterfaceName()
t.Cleanup(func() { SetVPNInterfaceName(orig) })
SetVPNInterfaceName("utun3")
assert.Equal(t, "utun3", GetVPNInterfaceName())
}
// TestSetVPNInterfaceName_Empty verifies that setting an empty name is accepted.
func TestSetVPNInterfaceName_Empty(t *testing.T) {
orig := GetVPNInterfaceName()
t.Cleanup(func() { SetVPNInterfaceName(orig) })
SetVPNInterfaceName("")
assert.Equal(t, "", GetVPNInterfaceName())
}
// TestSetVPNInterfaceName_OverwritesPrevious verifies that the second call wins.
func TestSetVPNInterfaceName_OverwritesPrevious(t *testing.T) {
orig := GetVPNInterfaceName()
t.Cleanup(func() { SetVPNInterfaceName(orig) })
SetVPNInterfaceName("utun1")
SetVPNInterfaceName("utun9")
assert.Equal(t, "utun9", GetVPNInterfaceName())
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows && !android
package net

View File

@@ -1,25 +0,0 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,39 @@
//go:build ios || android
package net
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestAdvancedRouting_Mobile verifies that AdvancedRouting always returns true
// on mobile platforms.
func TestAdvancedRouting_Mobile(t *testing.T) {
assert.True(t, AdvancedRouting(), "AdvancedRouting must always be true on mobile")
}
// TestInit_Mobile verifies that Init is a no-op and does not panic.
func TestInit_Mobile(t *testing.T) {
// Should not panic.
Init()
// After Init, AdvancedRouting must still return true.
assert.True(t, AdvancedRouting())
}
// TestSetVPNInterfaceName_Mobile verifies that SetVPNInterfaceName is a no-op
// and does not panic.
func TestSetVPNInterfaceName_Mobile(t *testing.T) {
// Should not panic for any input.
SetVPNInterfaceName("utun0")
SetVPNInterfaceName("")
}
// TestGetVPNInterfaceName_Mobile verifies that GetVPNInterfaceName always
// returns an empty string on mobile.
func TestGetVPNInterfaceName_Mobile(t *testing.T) {
// Even after a SetVPNInterfaceName call (no-op), the getter returns "".
SetVPNInterfaceName("utun0")
assert.Equal(t, "", GetVPNInterfaceName(), "GetVPNInterfaceName must return empty string on mobile")
}

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || windows
//go:build windows
package net
@@ -24,22 +24,17 @@ func Init() {
}
func checkAdvancedRoutingSupport() bool {
legacyRouting := false
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
parsed, err := strconv.ParseBool(val)
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
} else {
legacyRouting = parsed
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
if legacyRouting {
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
if legacyRouting || netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}

View File

@@ -1,5 +0,0 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows
package net

View File

@@ -1,160 +0,0 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -0,0 +1,286 @@
//go:build darwin && !ios
package net
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// resetBoundIfaces clears global state before each test that manipulates it.
func resetBoundIfaces(t *testing.T) {
t.Helper()
ClearBoundInterfaces()
t.Cleanup(ClearBoundInterfaces)
}
// TestIsV6Network verifies the isV6Network helper correctly identifies v6 networks.
func TestIsV6Network(t *testing.T) {
tests := []struct {
network string
want bool
}{
{"tcp6", true},
{"udp6", true},
{"ip6", true},
{"tcp", false},
{"udp", false},
{"tcp4", false},
{"udp4", false},
{"ip4", false},
{"ip", false},
{"", false},
// Arbitrary suffix-6 strings
{"unix6", true},
{"custom6", true},
}
for _, tt := range tests {
t.Run(tt.network, func(t *testing.T) {
assert.Equal(t, tt.want, isV6Network(tt.network))
})
}
}
// TestSetBoundInterface_AFINET sets boundIface4 via AF_INET.
func TestSetBoundInterface_AFINET(t *testing.T) {
resetBoundIfaces(t)
iface := &net.Interface{Index: 5, Name: "en0"}
SetBoundInterface(unix.AF_INET, iface)
boundIfaceMu.RLock()
got := boundIface4
boundIfaceMu.RUnlock()
require.NotNil(t, got)
assert.Equal(t, 5, got.Index)
assert.Equal(t, "en0", got.Name)
}
// TestSetBoundInterface_AFINET6 sets boundIface6 via AF_INET6.
func TestSetBoundInterface_AFINET6(t *testing.T) {
resetBoundIfaces(t)
iface := &net.Interface{Index: 7, Name: "utun0"}
SetBoundInterface(unix.AF_INET6, iface)
boundIfaceMu.RLock()
got := boundIface6
boundIfaceMu.RUnlock()
require.NotNil(t, got)
assert.Equal(t, 7, got.Index)
assert.Equal(t, "utun0", got.Name)
}
// TestSetBoundInterface_Nil verifies nil iface is rejected without panicking.
func TestSetBoundInterface_Nil(t *testing.T) {
resetBoundIfaces(t)
// Should not panic, just log a warning and leave existing values untouched.
iface := &net.Interface{Index: 1, Name: "en1"}
SetBoundInterface(unix.AF_INET, iface)
SetBoundInterface(unix.AF_INET, nil)
boundIfaceMu.RLock()
got := boundIface4
boundIfaceMu.RUnlock()
// Value must be unchanged after the rejected nil write.
require.NotNil(t, got)
assert.Equal(t, 1, got.Index)
}
// TestSetBoundInterface_UnknownAF verifies unknown address families are ignored.
func TestSetBoundInterface_UnknownAF(t *testing.T) {
resetBoundIfaces(t)
iface := &net.Interface{Index: 3, Name: "en2"}
// Use an address family that is not AF_INET or AF_INET6.
SetBoundInterface(99, iface)
boundIfaceMu.RLock()
v4, v6 := boundIface4, boundIface6
boundIfaceMu.RUnlock()
assert.Nil(t, v4, "unknown AF must not populate boundIface4")
assert.Nil(t, v6, "unknown AF must not populate boundIface6")
}
// TestClearBoundInterfaces clears both cached interfaces.
func TestClearBoundInterfaces(t *testing.T) {
iface4 := &net.Interface{Index: 1, Name: "en0"}
iface6 := &net.Interface{Index: 2, Name: "en0"}
SetBoundInterface(unix.AF_INET, iface4)
SetBoundInterface(unix.AF_INET6, iface6)
ClearBoundInterfaces()
boundIfaceMu.RLock()
v4, v6 := boundIface4, boundIface6
boundIfaceMu.RUnlock()
assert.Nil(t, v4, "boundIface4 must be nil after clear")
assert.Nil(t, v6, "boundIface6 must be nil after clear")
}
// TestClearBoundInterfaces_Idempotent verifies clearing twice does not panic.
func TestClearBoundInterfaces_Idempotent(t *testing.T) {
ClearBoundInterfaces()
ClearBoundInterfaces()
boundIfaceMu.RLock()
v4, v6 := boundIface4, boundIface6
boundIfaceMu.RUnlock()
assert.Nil(t, v4)
assert.Nil(t, v6)
}
// TestBoundInterfaceFor_PreferSameFamily verifies v4 iface returned for "tcp" and
// v6 iface returned for "tcp6" when both slots are populated.
func TestBoundInterfaceFor_PreferSameFamily(t *testing.T) {
resetBoundIfaces(t)
en0 := &net.Interface{Index: 1, Name: "en0"}
en1 := &net.Interface{Index: 2, Name: "en1"}
SetBoundInterface(unix.AF_INET, en0)
SetBoundInterface(unix.AF_INET6, en1)
got4 := boundInterfaceFor("tcp", "1.2.3.4:80")
require.NotNil(t, got4)
assert.Equal(t, "en0", got4.Name, "tcp should prefer v4 interface")
got6 := boundInterfaceFor("tcp6", "[::1]:80")
require.NotNil(t, got6)
assert.Equal(t, "en1", got6.Name, "tcp6 should prefer v6 interface")
}
// TestBoundInterfaceFor_FallbackToOtherFamily returns the other family's iface
// when the preferred slot is empty.
func TestBoundInterfaceFor_FallbackToOtherFamily(t *testing.T) {
resetBoundIfaces(t)
// Only v4 populated.
en0 := &net.Interface{Index: 1, Name: "en0"}
SetBoundInterface(unix.AF_INET, en0)
// Asking for v6 should fall back to en0.
got := boundInterfaceFor("tcp6", "[::1]:80")
require.NotNil(t, got)
assert.Equal(t, "en0", got.Name)
}
// TestBoundInterfaceFor_BothEmpty returns nil when both slots are empty.
func TestBoundInterfaceFor_BothEmpty(t *testing.T) {
resetBoundIfaces(t)
got := boundInterfaceFor("tcp", "1.2.3.4:80")
assert.Nil(t, got)
got6 := boundInterfaceFor("tcp6", "[::1]:80")
assert.Nil(t, got6)
}
// TestZoneInterface_Empty returns nil for empty address.
func TestZoneInterface_Empty(t *testing.T) {
iface := zoneInterface("")
assert.Nil(t, iface)
}
// TestZoneInterface_NoZone returns nil when address has no zone.
func TestZoneInterface_NoZone(t *testing.T) {
// Regular IPv4 address with port — no zone identifier.
iface := zoneInterface("192.168.1.1:80")
assert.Nil(t, iface)
// Regular IPv6 address with port — no zone identifier.
iface = zoneInterface("[2001:db8::1]:80")
assert.Nil(t, iface)
// Plain IPv6 address without port or zone.
iface = zoneInterface("2001:db8::1")
assert.Nil(t, iface)
}
// TestZoneInterface_InvalidAddress returns nil for completely invalid strings.
func TestZoneInterface_InvalidAddress(t *testing.T) {
iface := zoneInterface("not-an-address")
assert.Nil(t, iface)
iface = zoneInterface("::::")
assert.Nil(t, iface)
}
// TestZoneInterface_NonExistentZoneName returns nil for a zone name that does
// not correspond to a real interface on the host.
func TestZoneInterface_NonExistentZoneName(t *testing.T) {
// Use an interface name that is very unlikely to exist.
iface := zoneInterface("fe80::1%nonexistentiface99999")
assert.Nil(t, iface)
}
// TestZoneInterface_NonExistentZoneIndex returns nil for a zone expressed as an
// integer index that is not in use.
func TestZoneInterface_NonExistentZoneIndex(t *testing.T) {
// Interface index 999999 should not exist on any test machine.
iface := zoneInterface("fe80::1%999999")
assert.Nil(t, iface)
}
// TestBoundInterfaceFor_SetThenClear verifies that clearing state causes
// boundInterfaceFor to return nil afterwards.
func TestBoundInterfaceFor_SetThenClear(t *testing.T) {
resetBoundIfaces(t)
en0 := &net.Interface{Index: 1, Name: "en0"}
SetBoundInterface(unix.AF_INET, en0)
got := boundInterfaceFor("tcp", "1.2.3.4:80")
require.NotNil(t, got, "should return iface while set")
ClearBoundInterfaces()
got = boundInterfaceFor("tcp", "1.2.3.4:80")
assert.Nil(t, got, "should return nil after clear")
}
// TestSetBoundInterface_OverwritesPreviousValue verifies that calling
// SetBoundInterface again updates the stored pointer.
func TestSetBoundInterface_OverwritesPreviousValue(t *testing.T) {
resetBoundIfaces(t)
first := &net.Interface{Index: 1, Name: "en0"}
second := &net.Interface{Index: 3, Name: "en1"}
SetBoundInterface(unix.AF_INET, first)
SetBoundInterface(unix.AF_INET, second)
boundIfaceMu.RLock()
got := boundIface4
boundIfaceMu.RUnlock()
require.NotNil(t, got)
assert.Equal(t, "en1", got.Name, "second call should overwrite first")
}
// TestBoundInterfaceFor_OnlyV6Populated returns v6 iface for v4 network
// when only v6 slot is filled.
func TestBoundInterfaceFor_OnlyV6Populated(t *testing.T) {
resetBoundIfaces(t)
en1 := &net.Interface{Index: 2, Name: "en1"}
SetBoundInterface(unix.AF_INET6, en1)
// v4 network, v4 slot empty → should fall back to v6 slot.
got := boundInterfaceFor("tcp", "1.2.3.4:80")
require.NotNil(t, got)
assert.Equal(t, "en1", got.Name)
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -137,8 +138,10 @@ func restoreResidualState(ctx context.Context, statePath string) error {
}
// clean up any remaining routes independently of the state file
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -187,23 +187,24 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostList := strings.Join(deduplicatedPatterns, ",")
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}

View File

@@ -116,37 +116,6 @@ func TestManager_PeerLimit(t *testing.T) {
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
}
func TestManager_MatchHostFormat(t *testing.T) {
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
peers := []PeerSSHInfo{
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
"should use Match host with comma-separated patterns")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable
t.Setenv(EnvForceSSHConfig, "true")

View File

@@ -2,6 +2,7 @@ package system
import (
"context"
"net"
"net/netip"
"strings"
@@ -144,6 +145,59 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v
}
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))

View File

@@ -2,8 +2,6 @@ package system
import (
"context"
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
@@ -44,66 +42,6 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// networkAddresses returns the list of network addresses on iOS.
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
// check that other platforms use and filter out link-local addresses as they
// are not useful for posture checks.
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: ""}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil

View File

@@ -1,66 +0,0 @@
//go:build !ios
package system
import (
"net"
"net/netip"
)
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
mac := iface.HardwareAddr.String()
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address, mac)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: mac}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}