Compare commits

..

1 Commits

Author SHA1 Message Date
coderabbitai[bot]
cb1eaf9e0d CodeRabbit Generated Unit Tests: Add unit tests for PR changes 2026-04-20 07:51:24 +00:00
157 changed files with 8024 additions and 7120 deletions

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.4"
SIGN_PIPE_VER: "v0.1.2"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -115,12 +115,6 @@ jobs:
release:
runs-on: ubuntu-latest-m
outputs:
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
env:
flags: ""
steps:
@@ -219,13 +213,10 @@ jobs:
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only)
id: tag_and_push_images
if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
run: |
set -euo pipefail
resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}"
@@ -234,17 +225,6 @@ jobs:
fi
}
ghcr_package_url() {
local image="$1" package encoded_package
package="${image#ghcr.io/}"
package="${package#*/}"
package="${package%%:*}"
encoded_package="${package//\//%2F}"
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
}
image_refs=()
tag_and_push() {
local src="$1" img_name tag dst
img_name="${src%%:*}"
@@ -253,56 +233,35 @@ jobs:
echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst"
docker push "$dst"
image_refs+=("$dst")
done
}
cat > /tmp/goreleaser-artifacts.json <<'JSON'
${{ steps.goreleaser.outputs.artifacts }}
JSON
export -f tag_and_push resolve_tags
mapfile -t src_images < <(
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
)
for src in "${src_images[@]}"; do
tag_and_push "$src"
done
{
echo "images_markdown<<EOF"
if [[ ${#image_refs[@]} -eq 0 ]]; then
echo "_No GHCR images were pushed._"
else
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
done
fi
echo "EOF"
} >> "$GITHUB_OUTPUT"
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
tag_and_push "$SRC"
done
- name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 7
- name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 7
- name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 7
- name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
@@ -311,8 +270,6 @@ jobs:
release_ui:
runs-on: ubuntu-latest
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps:
- name: Parse semver string
id: semver_parser
@@ -403,7 +360,6 @@ jobs:
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@v4
with:
name: release-ui
@@ -412,8 +368,6 @@ jobs:
release_ui_darwin:
runs-on: macos-latest
outputs:
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
@@ -448,110 +402,12 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
retention-days: 3
comment_release_artifacts:
name: Comment release artifacts
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin]
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const marker = '<!-- netbird-release-artifacts -->';
const { owner, repo } = context.repo;
const issue_number = context.payload.pull_request.number;
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
const artifactCell = (url, result) => {
if (url) return `[Download](${url})`;
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
};
const artifacts = [
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
];
const artifactRows = artifacts
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
.join('\n');
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
const body = [
marker,
'## Release artifacts',
'',
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
'',
'| Artifact | Link |',
'| --- | --- |',
artifactRows,
'',
'### GHCR images (amd64)',
ghcrImages,
'',
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
].join('\n');
const comments = await github.paginate(github.rest.issues.listComments, {
owner,
repo,
issue_number,
per_page: 100,
});
const previous = comments.find(comment =>
comment.user?.type === 'Bot' && comment.body?.includes(marker)
);
if (previous) {
await github.rest.issues.updateComment({
owner,
repo,
comment_id: previous.id,
body,
});
core.info(`Updated release artifacts comment ${previous.id}`);
} else {
const { data } = await github.rest.issues.createComment({
owner,
repo,
issue_number,
body,
});
core.info(`Created release artifacts comment ${data.id}`);
}
trigger_signer:
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin]

View File

@@ -9,8 +9,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
@@ -22,30 +20,4 @@ jobs:
ref: main
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_android_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/android-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_ios_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)

View File

@@ -8,7 +8,6 @@ import (
"os"
"slices"
"sync"
"time"
"golang.org/x/exp/maps"
@@ -16,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -28,7 +26,6 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -71,30 +68,7 @@ type Client struct {
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
}
// NewClient instantiate a new Client
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
}
func (c *Client) RenewTun(fd int) error {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := cc.Engine()
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd)
}
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
}
func (c *Client) Networks() *NetworkArray {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := cc.Engine()
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.getConnectClient()
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -7,5 +7,4 @@ package android
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
CacheDir() string
}

View File

@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}

View File

@@ -1,11 +0,0 @@
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
// versions, which returns EPERM to any other process that tries to insert
// rules into them. The workaround mirrors what Tailscale does: let firewalld
// itself add the accept rules to its own chains by trusting the interface.
package firewalld
// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"

View File

@@ -1,260 +0,0 @@
//go:build linux
package firewalld
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
)
const (
dbusDest = "org.fedoraproject.FirewallD1"
dbusPath = "/org/fedoraproject/FirewallD1"
dbusRootIface = "org.fedoraproject.FirewallD1"
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
errZoneAlreadySet = "ZONE_ALREADY_SET"
errAlreadyEnabled = "ALREADY_ENABLED"
errUnknownIface = "UNKNOWN_INTERFACE"
errNotEnabled = "NOT_ENABLED"
// callTimeout bounds each individual DBus or firewall-cmd invocation.
// A fresh context is created for each call so a slow DBus probe can't
// exhaust the deadline before the firewall-cmd fallback gets to run.
callTimeout = 3 * time.Second
)
var (
errDBusUnavailable = errors.New("firewalld dbus unavailable")
// trustLogOnce ensures the "added to trusted zone" message is logged at
// Info level only for the first successful add per process; repeat adds
// from other init paths are quieter.
trustLogOnce sync.Once
parentCtxMu sync.RWMutex
parentCtx context.Context = context.Background()
)
// SetParentContext installs a parent context whose cancellation aborts any
// in-flight TrustInterface call. It does not affect UntrustInterface, which
// always uses a fresh Background-rooted timeout so cleanup can still run
// during engine shutdown when the engine context is already cancelled.
func SetParentContext(ctx context.Context) {
parentCtxMu.Lock()
parentCtx = ctx
parentCtxMu.Unlock()
}
func getParentContext() context.Context {
parentCtxMu.RLock()
defer parentCtxMu.RUnlock()
return parentCtx
}
// TrustInterface places iface into firewalld's trusted zone if firewalld is
// running. It is idempotent and best-effort: errors are returned so callers
// can log, but a non-running firewalld is not an error. Only the first
// successful call per process logs at Info. Respects the parent context set
// via SetParentContext so startup-time cancellation unblocks it.
func TrustInterface(iface string) error {
parent := getParentContext()
if !isRunning(parent) {
return nil
}
if err := addTrusted(parent, iface); err != nil {
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
}
trustLogOnce.Do(func() {
log.Infof("added %s to firewalld trusted zone", iface)
})
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
return nil
}
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
// during shutdown after the engine context has been cancelled.
func UntrustInterface(iface string) error {
if !isRunning(context.Background()) {
return nil
}
if err := removeTrusted(context.Background(), iface); err != nil {
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
}
return nil
}
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, callTimeout)
}
func isRunning(parent context.Context) bool {
ctx, cancel := newCallContext(parent)
ok, err := isRunningDBus(ctx)
cancel()
if err == nil {
return ok
}
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
ctx, cancel = newCallContext(parent)
defer cancel()
return isRunningCLI(ctx)
}
return false
}
func addTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := addDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return addCLI(ctx, iface)
}
func removeTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := removeDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return removeCLI(ctx, iface)
}
func isRunningDBus(ctx context.Context) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
var zone string
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
}
return true, nil
}
func isRunningCLI(ctx context.Context) bool {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return false
}
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
}
func addDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errAlreadyEnabled) {
return nil
}
if dbusErrContains(call.Err, errZoneAlreadySet) {
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
if move.Err != nil {
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
}
return nil
}
return fmt.Errorf("firewalld addInterface: %w", call.Err)
}
func removeDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
return nil
}
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
}
func addCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
// --change-interface (no --permanent) binds the interface for the
// current runtime only; we do not want membership to persist across
// reboots because netbird re-asserts it on every startup.
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
).CombinedOutput()
if err != nil {
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
func removeCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
).CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
return nil
}
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
}
return nil
}
func dbusErrContains(err error, code string) bool {
if err == nil {
return false
}
var de dbus.Error
if errors.As(err, &de) {
for _, b := range de.Body {
if s, ok := b.(string); ok && strings.Contains(s, code) {
return true
}
}
}
return strings.Contains(err.Error(), code)
}

View File

@@ -1,49 +0,0 @@
//go:build linux
package firewalld
import (
"errors"
"testing"
"github.com/godbus/dbus/v5"
)
func TestDBusErrContains(t *testing.T) {
tests := []struct {
name string
err error
code string
want bool
}{
{"nil error", nil, errZoneAlreadySet, false},
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
{
"dbus.Error body match",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
errZoneAlreadySet,
true,
},
{
"dbus.Error body miss",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
errAlreadyEnabled,
false,
},
{
"dbus.Error non-string body falls back to Error()",
dbus.Error{Name: "x", Body: []any{123}},
"x",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := dbusErrContains(tc.err, tc.code)
if got != tc.want {
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
}
})
}
}

View File

@@ -1,25 +0,0 @@
//go:build !linux
package firewalld
import "context"
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
// intentionally empty: firewalld is a Linux-only daemon
}
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func TrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func UntrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -87,12 +86,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
}
// Trust after all fatal init steps so a later failure doesn't leave the
// interface in firewalld's trusted zone without a corresponding Close.
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -198,12 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
}
// Appending to merr intentionally blocks DeleteState below so ShutdownState
// stays persisted and the crash-recovery path retries firewalld cleanup.
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
@@ -230,11 +217,6 @@ func (m *Manager) AllowNetbird() error {
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}

View File

@@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -218,10 +217,6 @@ func (m *Manager) AllowNetbird() error {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
@@ -41,8 +40,6 @@ const (
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
@@ -136,10 +133,6 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
@@ -287,10 +280,6 @@ func (r *router) createContainers() error {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
}
@@ -1330,13 +1319,6 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false
}
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false

View File

@@ -3,9 +3,6 @@
package uspfilter
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -19,9 +16,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)
}
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil
}
@@ -30,8 +24,5 @@ func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird()
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}

View File

@@ -9,7 +9,6 @@ import (
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device

View File

@@ -31,20 +31,12 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct {
NameFunc func() string
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
func (i *IFaceMock) Name() string {
if i.NameFunc == nil {
return "wgtest"
}
return i.NameFunc()
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil {
return nil

View File

@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++
}
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
// routing-correctness checks above are the real assertions; the counts
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {

View File

@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}
// Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
if err := w.tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}

View File

@@ -1,113 +0,0 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
}
if u.address.Network.Contains(a) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}

View File

@@ -201,18 +201,7 @@ Pop $0
Function .onInit
StrCpy $INSTDIR "${INSTALL_DIR}"
; Default autostart to enabled so silent installs (/S) match the interactive default
StrCpy $AutostartEnabled "1"
; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
; in the 32-bit view. Fall back to it so upgrades still find them.
SetRegView 64
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
${If} $R0 == ""
SetRegView 32
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
SetRegView 64
${EndIf}
${If} $R0 != ""
# if silent install jump to uninstall step
IfSilent uninstall
@@ -225,10 +214,6 @@ ${If} $R0 != ""
${EndIf}
FunctionEnd
Function un.onInit
SetRegView 64
FunctionEnd
######################################################################
Section -MainProgram
${INSTALL_TYPE}
@@ -243,7 +228,6 @@ Section -MainProgram
!else
File /r "..\\dist\\netbird_windows_amd64\\"
!endif
File "..\\client\\ui\\assets\\netbird.png"
SectionEnd
######################################################################
@@ -263,11 +247,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -301,8 +283,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
@@ -341,7 +321,6 @@ DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM

View File

@@ -94,7 +94,6 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
cacheDir string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -104,7 +103,6 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
}
@@ -333,10 +331,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
@@ -344,7 +338,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err)
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)

View File

@@ -16,6 +16,7 @@ import (
"path/filepath"
"runtime"
"runtime/pprof"
"slices"
"sort"
"strings"
"time"
@@ -30,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const readmeContent = `Netbird debug bundle
@@ -232,7 +234,6 @@ type BundleGenerator struct {
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
cpuProfile []byte
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
@@ -255,7 +256,6 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
@@ -275,7 +275,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
@@ -288,7 +287,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return "", fmt.Errorf("create zip file: %w", err)
}
@@ -374,8 +373,15 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if err := g.addPlatformLog(); err != nil {
log.Errorf("failed to add logs to debug bundle: %v", err)
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {

View File

@@ -1,41 +0,0 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -1,25 +0,0 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -3,12 +3,10 @@ package debug
import (
"context"
"errors"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci")
}
testDir := t.TempDir()
addr := reserveLoopbackPort(t)
testURL := "http://" + addr
testURL := "http://localhost:8080"
t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir)
srv := server.NewServer()
go func() {
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err)
}
})
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content")
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent)
}
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}

View File

@@ -13,7 +13,6 @@ import (
const (
defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
)
type resolvConf struct {

View File

@@ -1,10 +1,7 @@
package dns
import (
"context"
"fmt"
"math"
"net"
"slices"
"strconv"
"strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 {
return
}
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
// Try handlers in priority order
for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) {
continue
}
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime))
}
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
}
}
}
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,15 +1,11 @@
package dns_test
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
})
}
}
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
}
func newHostManager(wgInterface string) (hostManager, error) {
osManager, reason, err := getOSDNSManagerType()
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err)
}
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
log.Infof("System DNS manager discovered: %s", osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
}
}
func getOSDNSManagerType() (osManagerType, string, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
func getOSDNSManagerType() (osManagerType, error) {
file, err := os.Open(defaultResolvConfPath)
if err != nil {
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
}
defer func() {
if cerr := file.Close(); cerr != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
}
}()
var rejected []string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
continue
}
if text[0] != '#' {
break
return fileManager, nil
}
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
return mgr, reason, nil, nil
} else if rej != "" {
rejected = append(rejected, rej)
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
}
}
if err := scanner.Err(); err != nil && err != io.EOF {
return 0, "", nil, fmt.Errorf("scan: %w", err)
return 0, fmt.Errorf("scan: %w", err)
}
return 0, "", rejected, nil
return fileManager, nil
}
// matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
func checkStub() bool {
rConf, err := parseDefaultResolvConf()
if err != nil {
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
log.Warnf("failed to parse resolv conf: %s", err)
return true
}
@@ -178,36 +139,3 @@ func checkStub() bool {
return false
}
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -1,76 +0,0 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,83 +2,40 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
const dnsTimeout = 5 * time.Second
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question]*cachedRecord
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question][]dns.RR),
}
}
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
}
}
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return err
}
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
if err := errors.Join(errA, errAAAA); err != nil {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
}
now := time.Now()
m.mutex.Lock()
defer m.mutex.Unlock()
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
// untouched on error. Caller holds m.mutex.
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
// unique in-flight key; bursty stale hits share its channel. expected is the
// cachedRecord pointer observed by the caller; the refresh only mutates the
// cache if that pointer is still the one stored, so a stale in-flight refresh
// can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
delete(m.records, qA)
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -1,408 +0,0 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,7 +6,6 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains())
}
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -212,7 +212,6 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
ctx: ctx,

View File

@@ -26,7 +26,6 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
@@ -141,7 +140,6 @@ type EngineConfig struct {
ProfileConfig *profilemanager.Config
LogPath string
TempDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -571,7 +569,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed())
e.srWatcher.Start()
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -605,8 +603,6 @@ func (e *Engine) createFirewall() error {
return nil
}
firewalld.SetParentContext(e.ctx)
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil {
@@ -944,12 +940,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err)
}
urls := update.Urls
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
e.relayManager.UpdateServerURLs(update.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
@@ -1104,7 +1095,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)

View File

@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -3,6 +3,7 @@ package activity
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
@@ -17,6 +18,10 @@ import (
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn
@@ -176,6 +181,10 @@ func TestBindListener_Close(t *testing.T) {
}
func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -217,6 +226,10 @@ func TestManager_BindMode(t *testing.T) {
}
func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}

View File

@@ -4,12 +4,14 @@ import (
"errors"
"net"
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -73,6 +75,16 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is used on Windows, JS, and netstack platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider)
if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")

View File

@@ -22,8 +22,4 @@ type MobileDependency struct {
DnsManager dns.IosDnsManager
FileDescriptor int32
StateFilePath string
// TempDir is a writable directory for temporary files (e.g., debug bundle zip).
// On Android, this should be set to the app's cache directory.
TempDir string
}

View File

@@ -185,20 +185,17 @@ func (conn *Conn) Open(engineCtx context.Context) error {
conn.workerRelay = NewWorkerRelay(conn.ctx, conn.Log, isController(conn.config), conn.config, conn, conn.relayManager)
forceRelay := IsForceRelayed()
if !forceRelay {
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
if err != nil {
return err
}
conn.workerICE = workerICE
conn.handshaker = NewHandshaker(conn.Log, conn.config, conn.signaler, conn.workerICE, conn.workerRelay, conn.metricsStages)
conn.handshaker.AddRelayListener(conn.workerRelay.OnNewOffer)
if !forceRelay {
if !isForceRelayed() {
conn.handshaker.AddICEListener(conn.workerICE.OnNewOffer)
}
@@ -254,9 +251,7 @@ func (conn *Conn) Close(signalToRemote bool) {
conn.wgWatcherCancel()
}
conn.workerRelay.CloseConn()
if conn.workerICE != nil {
conn.workerICE.Close()
}
conn.workerICE.Close()
if conn.wgProxyRelay != nil {
err := conn.wgProxyRelay.CloseConn()
@@ -299,9 +294,7 @@ func (conn *Conn) OnRemoteAnswer(answer OfferAnswer) {
// OnRemoteCandidate Handles ICE connection Candidate provided by the remote peer.
func (conn *Conn) OnRemoteCandidate(candidate ice.Candidate, haRoutes route.HAMap) {
conn.dumpState.RemoteCandidate()
if conn.workerICE != nil {
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
conn.workerICE.OnRemoteCandidate(candidate, haRoutes)
}
// SetOnConnected sets a handler function to be triggered by Conn when a new connection to a remote peer established
@@ -719,35 +712,33 @@ func (conn *Conn) evalStatus() ConnStatus {
return StatusConnecting
}
// isConnectedOnAllWay evaluates the overall connection status based on ICE and Relay transports.
//
// The result is a tri-state:
// - ConnStatusConnected: all available transports are up
// - ConnStatusPartiallyConnected: relay is up but ICE is still pending/reconnecting
// - ConnStatusDisconnected: no working transport
func (conn *Conn) isConnectedOnAllWay() (status guard.ConnStatus) {
func (conn *Conn) isConnectedOnAllWay() (connected bool) {
// would be better to protect this with a mutex, but it could cause deadlock with Close function
defer func() {
if status == guard.ConnStatusDisconnected {
if !connected {
conn.logTraceConnState()
}
}()
iceWorkerCreated := conn.workerICE != nil
var iceInProgress bool
if iceWorkerCreated {
iceInProgress = conn.workerICE.InProgress()
// For JS platform: only relay connection is supported
if runtime.GOOS == "js" {
return conn.statusRelay.Get() == worker.StatusConnected
}
return evalConnStatus(connStatusInputs{
forceRelay: IsForceRelayed(),
peerUsesRelay: conn.workerRelay.IsRelayConnectionSupportedWithPeer(),
relayConnected: conn.statusRelay.Get() == worker.StatusConnected,
remoteSupportsICE: conn.handshaker.RemoteICESupported(),
iceWorkerCreated: iceWorkerCreated,
iceStatusConnecting: conn.statusICE.Get() != worker.StatusDisconnected,
iceInProgress: iceInProgress,
})
// For non-JS platforms: check ICE connection status
if conn.statusICE.Get() == worker.StatusDisconnected && !conn.workerICE.InProgress() {
return false
}
// If relay is supported with peer, it must also be connected
if conn.workerRelay.IsRelayConnectionSupportedWithPeer() {
if conn.statusRelay.Get() == worker.StatusDisconnected {
return false
}
}
return true
}
func (conn *Conn) enableWgWatcherIfNeeded(enabledTime time.Time) {
@@ -935,43 +926,3 @@ func isController(config ConnConfig) bool {
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
return remoteRosenpassPubKey != nil
}
func evalConnStatus(in connStatusInputs) guard.ConnStatus {
// "Relay up and needed" — the peer uses relay and the transport is connected.
relayUsedAndUp := in.peerUsesRelay && in.relayConnected
// Force-relay mode: ICE never runs. Relay is the only transport and must be up.
if in.forceRelay {
return boolToConnStatus(relayUsedAndUp)
}
// Remote peer doesn't support ICE, or we haven't created the worker yet:
// relay is the only possible transport.
if !in.remoteSupportsICE || !in.iceWorkerCreated {
return boolToConnStatus(relayUsedAndUp)
}
// ICE counts as "up" when the status is anything other than Disconnected, OR
// when a negotiation is currently in progress (so we don't spam offers while one is in flight).
iceUp := in.iceStatusConnecting || in.iceInProgress
// Relay side is acceptable if the peer doesn't rely on relay, or relay is connected.
relayOK := !in.peerUsesRelay || in.relayConnected
switch {
case iceUp && relayOK:
return guard.ConnStatusConnected
case relayUsedAndUp:
// Relay is up but ICE is down — partially connected.
return guard.ConnStatusPartiallyConnected
default:
return guard.ConnStatusDisconnected
}
}
func boolToConnStatus(connected bool) guard.ConnStatus {
if connected {
return guard.ConnStatusConnected
}
return guard.ConnStatusDisconnected
}

View File

@@ -13,20 +13,6 @@ const (
StatusConnected
)
// connStatusInputs is the primitive-valued snapshot of the state that drives the
// tri-state connection classification. Extracted so the decision logic can be unit-tested
// without constructing full Worker/Handshaker objects.
type connStatusInputs struct {
forceRelay bool // NB_FORCE_RELAY or JS/WASM
peerUsesRelay bool // remote peer advertises relay support AND local has relay
relayConnected bool // statusRelay reports Connected (independent of whether peer uses relay)
remoteSupportsICE bool // remote peer sent ICE credentials
iceWorkerCreated bool // local WorkerICE exists (false in force-relay mode)
iceStatusConnecting bool // statusICE is anything other than Disconnected
iceInProgress bool // a negotiation is currently in flight
}
// ConnStatus describe the status of a peer's connection
type ConnStatus int32

View File

@@ -1,201 +0,0 @@
package peer
import (
"testing"
"github.com/netbirdio/netbird/client/internal/peer/guard"
)
func TestEvalConnStatus_ForceRelay(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "force relay, peer uses relay, relay up",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: true,
},
want: guard.ConnStatusConnected,
},
{
name: "force relay, peer uses relay, relay down",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: true,
relayConnected: false,
},
want: guard.ConnStatusDisconnected,
},
{
name: "force relay, peer does NOT use relay - disconnected forever",
in: connStatusInputs{
forceRelay: true,
peerUsesRelay: false,
relayConnected: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_ICEUnavailable(t *testing.T) {
tests := []struct {
name string
in connStatusInputs
want guard.ConnStatus
}{
{
name: "remote does not support ICE, peer uses relay, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer uses relay, relay down",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE worker not yet created, relay up",
in: connStatusInputs{
peerUsesRelay: true,
relayConnected: true,
remoteSupportsICE: true,
iceWorkerCreated: false,
},
want: guard.ConnStatusConnected,
},
{
name: "remote does not support ICE, peer does not use relay",
in: connStatusInputs{
peerUsesRelay: false,
relayConnected: false,
remoteSupportsICE: false,
iceWorkerCreated: true,
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
if got := evalConnStatus(tc.in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v", got, tc.want)
}
})
}
}
func TestEvalConnStatus_FullyAvailable(t *testing.T) {
base := connStatusInputs{
remoteSupportsICE: true,
iceWorkerCreated: true,
}
tests := []struct {
name string
mutator func(*connStatusInputs)
want guard.ConnStatus
}{
{
name: "ICE connected, relay connected, peer uses relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE connected, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE InProgress only, peer does NOT use relay",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.iceStatusConnecting = false
in.iceInProgress = true
},
want: guard.ConnStatusConnected,
},
{
name: "ICE down, relay up, peer uses relay -> partial",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = true
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusPartiallyConnected,
},
{
name: "ICE down, peer does NOT use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = false
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
{
name: "ICE up, peer uses relay but relay down -> partial (relay required, ICE ignored)",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = true
in.relayConnected = false
in.iceStatusConnecting = true
},
// relayOK = false (peer uses relay but it's down), iceUp = true
// first switch arm fails (relayOK false), relayUsedAndUp = false (relay down),
// falls into default: Disconnected.
want: guard.ConnStatusDisconnected,
},
{
name: "ICE down, relay up but peer does not use relay -> disconnected",
mutator: func(in *connStatusInputs) {
in.peerUsesRelay = false
in.relayConnected = true // not actually used since peer doesn't rely on it
in.iceStatusConnecting = false
in.iceInProgress = false
},
want: guard.ConnStatusDisconnected,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
in := base
tc.mutator(&in)
if got := evalConnStatus(in); got != tc.want {
t.Fatalf("evalConnStatus = %v, want %v (inputs: %+v)", got, tc.want, in)
}
})
}
}

View File

@@ -7,38 +7,12 @@ import (
)
const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
)
func IsForceRelayed() bool {
func isForceRelayed() bool {
if runtime.GOOS == "js" {
return true
}
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
}
// OverrideRelayURLs returns the relay server URL list set in
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
// the override is active. When the env var is unset, the boolean is false
// and the caller should keep the list received from the management server.
// Intended for lab/debug scenarios where a peer must pin to a specific home
// relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
raw := os.Getenv(EnvKeyNBHomeRelayServers)
if raw == "" {
return nil, false
}
parts := strings.Split(raw, ",")
urls := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
urls = append(urls, p)
}
}
if len(urls) == 0 {
return nil, false
}
return urls, true
}

View File

@@ -8,19 +8,7 @@ import (
log "github.com/sirupsen/logrus"
)
// ConnStatus represents the connection state as seen by the guard.
type ConnStatus int
const (
// ConnStatusDisconnected means neither ICE nor Relay is connected.
ConnStatusDisconnected ConnStatus = iota
// ConnStatusPartiallyConnected means Relay is connected but ICE is not.
ConnStatusPartiallyConnected
// ConnStatusConnected means all required connections are established.
ConnStatusConnected
)
type connStatusFunc func() ConnStatus
type isConnectedFunc func() bool
// Guard is responsible for the reconnection logic.
// It will trigger to send an offer to the peer then has connection issues.
@@ -32,14 +20,14 @@ type connStatusFunc func() ConnStatus
// - ICE candidate changes
type Guard struct {
log *log.Entry
isConnectedOnAllWay connStatusFunc
isConnectedOnAllWay isConnectedFunc
timeout time.Duration
srWatcher *SRWatcher
relayedConnDisconnected chan struct{}
iCEConnDisconnected chan struct{}
}
func NewGuard(log *log.Entry, isConnectedFn connStatusFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
func NewGuard(log *log.Entry, isConnectedFn isConnectedFunc, timeout time.Duration, srWatcher *SRWatcher) *Guard {
return &Guard{
log: log,
isConnectedOnAllWay: isConnectedFn,
@@ -69,17 +57,8 @@ func (g *Guard) SetICEConnDisconnected() {
}
}
// reconnectLoopWithRetry periodically checks the connection status and sends offers to re-establish connectivity.
//
// Behavior depends on the connection state reported by isConnectedOnAllWay:
// - Connected: no action, the peer is fully reachable.
// - Disconnected (neither ICE nor Relay): retries aggressively with exponential backoff (800ms doubling
// up to timeout), never gives up. This ensures rapid recovery when the peer has no connectivity at all.
// - PartiallyConnected (Relay up, ICE not): retries up to 3 times with exponential backoff, then switches
// to one attempt per hour. This limits signaling traffic when relay already provides connectivity.
//
// External events (relay/ICE disconnect, signal/relay reconnect, candidate changes) reset the retry
// counter and backoff ticker, giving ICE a fresh chance after network conditions change.
// reconnectLoopWithRetry periodically check the connection status.
// Try to send offer while the P2P is not established or while the Relay is not connected if is it supported
func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
srReconnectedChan := g.srWatcher.NewListener()
defer g.srWatcher.RemoveListener(srReconnectedChan)
@@ -89,47 +68,36 @@ func (g *Guard) reconnectLoopWithRetry(ctx context.Context, callback func()) {
tickerChannel := ticker.C
iceState := &iceRetryState{log: g.log}
defer iceState.reset()
for {
select {
case <-tickerChannel:
switch g.isConnectedOnAllWay() {
case ConnStatusConnected:
// all good, nothing to do
case ConnStatusDisconnected:
callback()
case ConnStatusPartiallyConnected:
if iceState.shouldRetry() {
callback()
} else {
iceState.enterHourlyMode()
ticker.Stop()
tickerChannel = iceState.hourlyC()
}
case t := <-tickerChannel:
if t.IsZero() {
g.log.Infof("retry timed out, stop periodic offer sending")
// after backoff timeout the ticker.C will be closed. We need to a dummy channel to avoid loop
tickerChannel = make(<-chan time.Time)
continue
}
if !g.isConnectedOnAllWay() {
callback()
}
case <-g.relayedConnDisconnected:
g.log.Debugf("Relay connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.newReconnectTicker(ctx)
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-g.iCEConnDisconnected:
g.log.Debugf("ICE connection changed, reset reconnection ticker")
ticker.Stop()
ticker = g.newReconnectTicker(ctx)
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-srReconnectedChan:
g.log.Debugf("has network changes, reset reconnection ticker")
ticker.Stop()
ticker = g.newReconnectTicker(ctx)
ticker = g.prepareExponentTicker(ctx)
tickerChannel = ticker.C
iceState.reset()
case <-ctx.Done():
g.log.Debugf("context is done, stop reconnect loop")
@@ -152,7 +120,7 @@ func (g *Guard) initialTicker(ctx context.Context) *backoff.Ticker {
return backoff.NewTicker(bo)
}
func (g *Guard) newReconnectTicker(ctx context.Context) *backoff.Ticker {
func (g *Guard) prepareExponentTicker(ctx context.Context) *backoff.Ticker {
bo := backoff.WithContext(&backoff.ExponentialBackOff{
InitialInterval: 800 * time.Millisecond,
RandomizationFactor: 0.1,

View File

@@ -1,61 +0,0 @@
package guard
import (
"time"
log "github.com/sirupsen/logrus"
)
const (
// maxICERetries is the maximum number of ICE offer attempts when relay is connected
maxICERetries = 3
// iceRetryInterval is the periodic retry interval after ICE retries are exhausted
iceRetryInterval = 1 * time.Hour
)
// iceRetryState tracks the limited ICE retry attempts when relay is already connected.
// After maxICERetries attempts it switches to a periodic hourly retry.
type iceRetryState struct {
log *log.Entry
retries int
hourly *time.Ticker
}
func (s *iceRetryState) reset() {
s.retries = 0
if s.hourly != nil {
s.hourly.Stop()
s.hourly = nil
}
}
// shouldRetry reports whether the caller should send another ICE offer on this tick.
// Returns false when the per-cycle retry budget is exhausted and the caller must switch
// to the hourly ticker via enterHourlyMode + hourlyC.
func (s *iceRetryState) shouldRetry() bool {
if s.hourly != nil {
s.log.Debugf("hourly ICE retry attempt")
return true
}
s.retries++
if s.retries <= maxICERetries {
s.log.Debugf("ICE retry attempt %d/%d", s.retries, maxICERetries)
return true
}
return false
}
// enterHourlyMode starts the hourly retry ticker. Must be called after shouldRetry returns false.
func (s *iceRetryState) enterHourlyMode() {
s.log.Infof("ICE retries exhausted (%d/%d), switching to hourly retry", maxICERetries, maxICERetries)
s.hourly = time.NewTicker(iceRetryInterval)
}
func (s *iceRetryState) hourlyC() <-chan time.Time {
if s.hourly == nil {
return nil
}
return s.hourly.C
}

View File

@@ -1,103 +0,0 @@
package guard
import (
"testing"
log "github.com/sirupsen/logrus"
)
func newTestRetryState() *iceRetryState {
return &iceRetryState{log: log.NewEntry(log.StandardLogger())}
}
func TestICERetryState_AllowsInitialBudget(t *testing.T) {
s := newTestRetryState()
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d, want true (budget = %d)", i, maxICERetries)
}
}
}
func TestICERetryState_ExhaustsAfterBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries; i++ {
_ = s.shouldRetry()
}
if s.shouldRetry() {
t.Fatalf("shouldRetry returned true after budget exhausted, want false")
}
}
func TestICERetryState_HourlyCNilBeforeEnterHourlyMode(t *testing.T) {
s := newTestRetryState()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel before enterHourlyMode")
}
}
func TestICERetryState_EnterHourlyModeArmsTicker(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
defer s.reset()
if s.hourlyC() == nil {
t.Fatalf("hourlyC returned nil after enterHourlyMode")
}
}
func TestICERetryState_ShouldRetryTrueInHourlyMode(t *testing.T) {
s := newTestRetryState()
s.enterHourlyMode()
defer s.reset()
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false in hourly mode, want true")
}
// Subsequent calls also return true — we keep retrying on each hourly tick.
if !s.shouldRetry() {
t.Fatalf("second shouldRetry returned false in hourly mode, want true")
}
}
func TestICERetryState_ResetRestoresBudget(t *testing.T) {
s := newTestRetryState()
for i := 0; i < maxICERetries+1; i++ {
_ = s.shouldRetry()
}
s.enterHourlyMode()
s.reset()
if s.hourlyC() != nil {
t.Fatalf("hourlyC returned non-nil channel after reset")
}
if s.retries != 0 {
t.Fatalf("retries = %d after reset, want 0", s.retries)
}
for i := 1; i <= maxICERetries; i++ {
if !s.shouldRetry() {
t.Fatalf("shouldRetry returned false on attempt %d after reset, want true", i)
}
}
}
func TestICERetryState_ResetIsIdempotent(t *testing.T) {
s := newTestRetryState()
s.reset()
s.reset() // second call must not panic or re-stop a nil ticker
if s.hourlyC() != nil {
t.Fatalf("hourlyC non-nil after double reset")
}
}

View File

@@ -39,7 +39,7 @@ func NewSRWatcher(signalClient chNotifier, relayManager chNotifier, iFaceDiscove
return srw
}
func (w *SRWatcher) Start(disableICEMonitor bool) {
func (w *SRWatcher) Start() {
w.mu.Lock()
defer w.mu.Unlock()
@@ -50,10 +50,8 @@ func (w *SRWatcher) Start(disableICEMonitor bool) {
ctx, cancel := context.WithCancel(context.Background())
w.cancelIceMonitor = cancel
if !disableICEMonitor {
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
}
iceMonitor := NewICEMonitor(w.iFaceDiscover, w.iceConfig, GetICEMonitorPeriod())
go iceMonitor.Start(ctx, w.onICEChanged)
w.signalClient.SetOnReconnectedListener(w.onReconnected)
w.relayManager.SetOnReconnectedListener(w.onReconnected)

View File

@@ -4,7 +4,6 @@ import (
"context"
"errors"
"sync"
"sync/atomic"
log "github.com/sirupsen/logrus"
@@ -44,10 +43,6 @@ type OfferAnswer struct {
SessionID *ICESessionID
}
func (o *OfferAnswer) hasICECredentials() bool {
return o.IceCredentials.UFrag != "" && o.IceCredentials.Pwd != ""
}
type Handshaker struct {
mu sync.Mutex
log *log.Entry
@@ -64,10 +59,6 @@ type Handshaker struct {
relayListener *AsyncOfferListener
iceListener func(remoteOfferAnswer *OfferAnswer)
// remoteICESupported tracks whether the remote peer includes ICE credentials in its offers/answers.
// When false, the local side skips ICE listener dispatch and suppresses ICE credentials in responses.
remoteICESupported atomic.Bool
// remoteOffersCh is a channel used to wait for remote credentials to proceed with the connection
remoteOffersCh chan OfferAnswer
// remoteAnswerCh is a channel used to wait for remote credentials answer (confirmation of our offer) to proceed with the connection
@@ -75,7 +66,7 @@ type Handshaker struct {
}
func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *WorkerICE, relay *WorkerRelay, metricsStages *MetricsStages) *Handshaker {
h := &Handshaker{
return &Handshaker{
log: log,
config: config,
signaler: signaler,
@@ -85,13 +76,6 @@ func NewHandshaker(log *log.Entry, config ConnConfig, signaler *Signaler, ice *W
remoteOffersCh: make(chan OfferAnswer),
remoteAnswerCh: make(chan OfferAnswer),
}
// assume remote supports ICE until we learn otherwise from received offers
h.remoteICESupported.Store(ice != nil)
return h
}
func (h *Handshaker) RemoteICESupported() bool {
return h.remoteICESupported.Load()
}
func (h *Handshaker) AddRelayListener(offer func(remoteOfferAnswer *OfferAnswer)) {
@@ -106,20 +90,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
for {
select {
case remoteOfferAnswer := <-h.remoteOffersCh:
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
h.log.Infof("received offer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil && h.RemoteICESupported() {
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
@@ -128,20 +110,18 @@ func (h *Handshaker) Listen(ctx context.Context) {
continue
}
case remoteOfferAnswer := <-h.remoteAnswerCh:
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s, remote ICE supported: %t", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString(), remoteOfferAnswer.hasICECredentials())
h.log.Infof("received answer, running version %s, remote WireGuard listen port %d, session id: %s", remoteOfferAnswer.Version, remoteOfferAnswer.WgListenPort, remoteOfferAnswer.SessionIDString())
// Record signaling received for reconnection attempts
if h.metricsStages != nil {
h.metricsStages.RecordSignalingReceived()
}
h.updateRemoteICEState(&remoteOfferAnswer)
if h.relayListener != nil {
h.relayListener.Notify(&remoteOfferAnswer)
}
if h.iceListener != nil && h.RemoteICESupported() {
if h.iceListener != nil {
h.iceListener(&remoteOfferAnswer)
}
case <-ctx.Done():
@@ -203,18 +183,15 @@ func (h *Handshaker) sendAnswer() error {
}
func (h *Handshaker) buildOfferAnswer() OfferAnswer {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer := OfferAnswer{
IceCredentials: IceCredentials{uFrag, pwd},
WgListenPort: h.config.LocalWgPort,
Version: version.NetbirdVersion(),
RosenpassPubKey: h.config.RosenpassConfig.PubKey,
RosenpassAddr: h.config.RosenpassConfig.Addr,
}
if h.ice != nil && h.RemoteICESupported() {
uFrag, pwd := h.ice.GetLocalUserCredentials()
sid := h.ice.SessionID()
answer.IceCredentials = IceCredentials{uFrag, pwd}
answer.SessionID = &sid
SessionID: &sid,
}
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
@@ -223,18 +200,3 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
return answer
}
func (h *Handshaker) updateRemoteICEState(offer *OfferAnswer) {
hasICE := offer.hasICECredentials()
prev := h.remoteICESupported.Swap(hasICE)
if prev != hasICE {
if hasICE {
h.log.Infof("remote peer started sending ICE credentials")
} else {
h.log.Infof("remote peer stopped sending ICE credentials")
if h.ice != nil {
h.ice.Close()
}
}
}
}

View File

@@ -46,13 +46,9 @@ func (s *Signaler) Ready() bool {
// SignalOfferAnswer signals either an offer or an answer to remote peer
func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string, bodyType sProto.Body_Type) error {
var sessionIDBytes []byte
if offerAnswer.SessionID != nil {
var err error
sessionIDBytes, err = offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
sessionIDBytes, err := offerAnswer.SessionID.Bytes()
if err != nil {
log.Warnf("failed to get session ID bytes: %v", err)
}
msg, err := signal.MarshalCredential(
s.wgPrivateKey,

View File

@@ -1,10 +0,0 @@
//go:build (dragonfly || freebsd || netbsd || openbsd) && !darwin
package systemops
// Non-darwin BSDs don't support the IP_BOUND_IF + scoped default model. They
// always fall through to the ref-counter exclusion-route path; these stubs
// exist only so systemops_unix.go compiles.
func (r *SysOps) setupAdvancedRouting() error { return nil }
func (r *SysOps) cleanupAdvancedRouting() error { return nil }
func (r *SysOps) flushPlatformExtras() error { return nil }

View File

@@ -1,253 +0,0 @@
//go:build darwin && !ios
package systemops
import (
"errors"
"fmt"
"net/netip"
"os"
"time"
"github.com/hashicorp/go-multierror"
log "github.com/sirupsen/logrus"
"golang.org/x/net/route"
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
nbnet "github.com/netbirdio/netbird/client/net"
)
// scopedRouteBudget bounds retries for the scoped default route. Installing or
// deleting it matters enough that we're willing to spend longer waiting for the
// kernel reply than for per-prefix exclusion routes.
const scopedRouteBudget = 5 * time.Second
// setupAdvancedRouting installs an RTF_IFSCOPE default route per address family
// pinned to the current physical egress, so IP_BOUND_IF scoped lookups can
// resolve gateway'd destinations while the VPN's split default owns the
// unscoped table.
//
// Timing note: this runs during routeManager.Init, which happens before the
// VPN interface is created and before any peer routes propagate. The initial
// mgmt / signal / relay TCP dials always fire before this runs, so those
// sockets miss the IP_BOUND_IF binding and rely on the kernel's normal route
// lookup, which at that point correctly picks the physical default. Those
// already-established TCP flows keep their originally-selected interface for
// their lifetime on Darwin because the kernel caches the egress route
// per-socket at connect time; adding the VPN's 0/1 + 128/1 split default
// afterwards does not migrate them since the original en0 default stays in
// the table. Any subsequent reconnect via nbnet.NewDialer picks up the
// populated bound-iface cache and gets IP_BOUND_IF set cleanly.
func (r *SysOps) setupAdvancedRouting() error {
// Drop any previously-cached egress interface before reinstalling. On a
// refresh, a family that no longer resolves would otherwise keep the stale
// binding, causing new sockets to scope to an interface without a matching
// scoped default.
nbnet.ClearBoundInterfaces()
if err := r.flushScopedDefaults(); err != nil {
log.Warnf("flush residual scoped defaults: %v", err)
}
var merr *multierror.Error
installed := 0
for _, unspec := range []netip.Addr{netip.IPv4Unspecified(), netip.IPv6Unspecified()} {
ok, err := r.installScopedDefaultFor(unspec)
if err != nil {
merr = multierror.Append(merr, err)
continue
}
if ok {
installed++
}
}
if installed == 0 && merr != nil {
return nberrors.FormatErrorOrNil(merr)
}
if merr != nil {
log.Warnf("advanced routing setup partially succeeded: %v", nberrors.FormatErrorOrNil(merr))
}
return nil
}
// installScopedDefaultFor resolves the physical default nexthop for the given
// address family, installs a scoped default via it, and caches the iface for
// subsequent IP_BOUND_IF / IPV6_BOUND_IF socket binds.
func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
nexthop, err := GetNextHop(unspec)
if err != nil {
if errors.Is(err, vars.ErrRouteNotFound) {
return false, nil
}
return false, fmt.Errorf("get default nexthop for %s: %w", unspec, err)
}
if nexthop.Intf == nil {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
}
reused := false
if err := r.addScopedDefault(unspec, nexthop); err != nil {
if !errors.Is(err, unix.EEXIST) {
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
// macOS installs its own RTF_IFSCOPE defaults for primary service
// selection on multi-NIC setups, so a route on this ifindex can
// already exist before we try. Binding to it via IP[V6]_BOUND_IF
// still produces the scoped lookup we need.
reused = true
}
af := unix.AF_INET
if unspec.Is6() {
af = unix.AF_INET6
}
nbnet.SetBoundInterface(af, nexthop.Intf)
via := "point-to-point"
if nexthop.IP.IsValid() {
via = nexthop.IP.String()
}
verb := "installed"
if reused {
verb = "reused existing"
}
log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
return true, nil
}
func (r *SysOps) cleanupAdvancedRouting() error {
nbnet.ClearBoundInterfaces()
return r.flushScopedDefaults()
}
// flushPlatformExtras runs darwin-specific residual cleanup hooked into the
// generic FlushMarkedRoutes path, so a crashed daemon's scoped defaults get
// removed on the next boot regardless of whether a profile is brought up.
func (r *SysOps) flushPlatformExtras() error {
return r.flushScopedDefaults()
}
// flushScopedDefaults removes any scoped default routes tagged with routeProtoFlag.
// Safe to call at startup to clear residual entries from a prior session.
func (r *SysOps) flushScopedDefaults() error {
rib, err := retryFetchRIB()
if err != nil {
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
removed := 0
for _, msg := range msgs {
rtMsg, ok := msg.(*route.RouteMessage)
if !ok {
continue
}
if rtMsg.Flags&routeProtoFlag == 0 {
continue
}
if rtMsg.Flags&unix.RTF_IFSCOPE == 0 {
continue
}
info, err := MsgToRoute(rtMsg)
if err != nil {
log.Debugf("skip scoped flush: %v", err)
continue
}
if !info.Dst.IsValid() || info.Dst.Bits() != 0 {
continue
}
if err := r.deleteScopedRoute(rtMsg); err != nil {
merr = multierror.Append(merr, fmt.Errorf("delete scoped default %s on index %d: %w",
info.Dst, rtMsg.Index, err))
continue
}
removed++
log.Debugf("flushed residual scoped default %s on index %d", info.Dst, rtMsg.Index)
}
if removed > 0 {
log.Infof("flushed %d residual scoped default route(s)", removed)
}
return nberrors.FormatErrorOrNil(merr)
}
func (r *SysOps) addScopedDefault(unspec netip.Addr, nexthop Nexthop) error {
return r.scopedRouteSocket(unix.RTM_ADD, unspec, nexthop)
}
func (r *SysOps) deleteScopedRoute(rtMsg *route.RouteMessage) error {
// Preserve identifying flags from the stored route (including RTF_GATEWAY
// only if present); kernel-set bits like RTF_DONE don't belong on RTM_DELETE.
keep := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_GATEWAY | unix.RTF_IFSCOPE | routeProtoFlag
del := &route.RouteMessage{
Type: unix.RTM_DELETE,
Flags: rtMsg.Flags & keep,
Version: unix.RTM_VERSION,
Seq: r.getSeq(),
Index: rtMsg.Index,
Addrs: rtMsg.Addrs,
}
return r.writeRouteMessage(del, scopedRouteBudget)
}
func (r *SysOps) scopedRouteSocket(action int, unspec netip.Addr, nexthop Nexthop) error {
flags := unix.RTF_UP | unix.RTF_STATIC | unix.RTF_IFSCOPE | routeProtoFlag
msg := &route.RouteMessage{
Type: action,
Flags: flags,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
Index: nexthop.Intf.Index,
}
const numAddrs = unix.RTAX_NETMASK + 1
addrs := make([]route.Addr, numAddrs)
dst, err := addrToRouteAddr(unspec)
if err != nil {
return fmt.Errorf("build destination: %w", err)
}
mask, err := prefixToRouteNetmask(netip.PrefixFrom(unspec, 0))
if err != nil {
return fmt.Errorf("build netmask: %w", err)
}
addrs[unix.RTAX_DST] = dst
addrs[unix.RTAX_NETMASK] = mask
if nexthop.IP.IsValid() {
msg.Flags |= unix.RTF_GATEWAY
gw, err := addrToRouteAddr(nexthop.IP.Unmap())
if err != nil {
return fmt.Errorf("build gateway: %w", err)
}
addrs[unix.RTAX_GATEWAY] = gw
} else {
addrs[unix.RTAX_GATEWAY] = &route.LinkAddr{
Index: nexthop.Intf.Index,
Name: nexthop.Intf.Name,
}
}
msg.Addrs = addrs
return r.writeRouteMessage(msg, scopedRouteBudget)
}
func afOf(a netip.Addr) string {
if a.Is4() {
return "IPv4"
}
return "IPv6"
}

View File

@@ -0,0 +1,125 @@
//go:build darwin && !ios
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
nbnet "github.com/netbirdio/netbird/client/net"
)
// TestAfOf verifies that afOf returns the correct string for each address family.
func TestAfOf(t *testing.T) {
tests := []struct {
name string
addr netip.Addr
want string
}{
{
name: "IPv4 unspecified",
addr: netip.IPv4Unspecified(),
want: "IPv4",
},
{
name: "IPv4 private",
addr: netip.MustParseAddr("10.0.0.1"),
want: "IPv4",
},
{
name: "IPv4 loopback",
addr: netip.MustParseAddr("127.0.0.1"),
want: "IPv4",
},
{
name: "IPv6 unspecified",
addr: netip.IPv6Unspecified(),
want: "IPv6",
},
{
name: "IPv6 loopback",
addr: netip.MustParseAddr("::1"),
want: "IPv6",
},
{
name: "IPv6 unicast",
addr: netip.MustParseAddr("2001:db8::1"),
want: "IPv6",
},
{
name: "IPv6 link-local",
addr: netip.MustParseAddr("fe80::1"),
want: "IPv6",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
assert.Equal(t, tt.want, afOf(tt.addr))
})
}
}
// TestIsAddrRouted_AdvancedRoutingBypassesTunnelLookup verifies that when
// AdvancedRouting is active, IsAddrRouted immediately returns (false, zero)
// regardless of the provided vpn routes, because the WG socket is bound to
// the physical interface via IP_BOUND_IF and bypasses the main routing table.
func TestIsAddrRouted_AdvancedRoutingBypassesTunnelLookup(t *testing.T) {
// On darwin, AdvancedRouting returns true unless overridden.
// Ensure we reset the state after the test.
t.Setenv("NB_USE_LEGACY_ROUTING", "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
nbnet.Init()
require.True(t, nbnet.AdvancedRouting(), "test requires advanced routing to be active on darwin")
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.1.0/24"),
netip.MustParsePrefix("0.0.0.0/0"),
}
tests := []struct {
name string
addr netip.Addr
}{
{"IPv4 in VPN route", netip.MustParseAddr("10.0.0.1")},
{"IPv4 in narrow VPN route", netip.MustParseAddr("192.168.1.100")},
{"IPv4 default route covered", netip.MustParseAddr("8.8.8.8")},
{"IPv6 in VPN route", netip.MustParseAddr("2001:db8::1")},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
routed, prefix := IsAddrRouted(tt.addr, vpnRoutes)
assert.False(t, routed, "should not be marked as routed via VPN when advanced routing is active")
assert.Equal(t, netip.Prefix{}, prefix, "matched prefix should be zero when advanced routing is active")
})
}
}
// TestIsAddrRouted_LegacyModeFallsThroughToTable verifies that when
// NB_USE_LEGACY_ROUTING=true disables advanced routing, IsAddrRouted
// performs the normal VPN-route vs local-route comparison.
func TestIsAddrRouted_LegacyModeFallsThroughToTable(t *testing.T) {
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
nbnet.Init()
require.False(t, nbnet.AdvancedRouting(), "test requires advanced routing to be disabled")
// Use an address that is very unlikely to exist in the host routing table
// as a local route, so the VPN route wins.
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("198.51.100.0/24"), // TEST-NET-2 not in normal routing tables
}
addr := netip.MustParseAddr("198.51.100.1")
routed, _ := IsAddrRouted(addr, vpnRoutes)
// We cannot assert a specific outcome because it depends on the host's
// routing table, but we CAN assert that the call did not panic and returned
// a consistent pair.
_ = routed
}

View File

@@ -21,7 +21,6 @@ import (
"github.com/netbirdio/netbird/client/internal/routemanager/util"
"github.com/netbirdio/netbird/client/internal/routemanager/vars"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/net/hooks"
)
@@ -32,6 +31,8 @@ var splitDefaultv4_2 = netip.PrefixFrom(netip.AddrFrom4([4]byte{128}), 1)
var splitDefaultv6_1 = netip.PrefixFrom(netip.IPv6Unspecified(), 1)
var splitDefaultv6_2 = netip.PrefixFrom(netip.AddrFrom16([16]byte{0x80}), 1)
var ErrRoutingIsSeparate = errors.New("routing is separate")
func (r *SysOps) setupRefCounter(initAddresses []net.IP, stateManager *statemanager.Manager) error {
stateManager.RegisterState(&ShutdownState{})
@@ -396,16 +397,12 @@ func ipToAddr(ip net.IP, intf *net.Interface) (netip.Addr, error) {
}
// IsAddrRouted checks if the candidate address would route to the vpn, in which case it returns true and the matched prefix.
// When advanced routing is active the WG socket is bound to the physical interface (fwmark on linux,
// IP_UNICAST_IF on windows, IP_BOUND_IF on darwin) and bypasses the main routing table, so the check is skipped.
func IsAddrRouted(addr netip.Addr, vpnRoutes []netip.Prefix) (bool, netip.Prefix) {
if nbnet.AdvancedRouting() {
return false, netip.Prefix{}
}
localRoutes, err := GetRoutesFromTable()
localRoutes, err := hasSeparateRouting()
if err != nil {
log.Errorf("Failed to get routes: %v", err)
if !errors.Is(err, ErrRoutingIsSeparate) {
log.Errorf("Failed to get routes: %v", err)
}
return false, netip.Prefix{}
}

View File

@@ -0,0 +1,140 @@
//go:build !android && !ios
package systemops
import (
"net/netip"
"testing"
"github.com/stretchr/testify/assert"
nbnet "github.com/netbirdio/netbird/client/net"
)
// withLegacyRouting forcibly puts the nbnet package into legacy (non-advanced)
// routing mode for the duration of the test, then restores the previous value.
func withLegacyRouting(t *testing.T) {
t.Helper()
t.Setenv("NB_USE_LEGACY_ROUTING", "true")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
nbnet.Init()
t.Cleanup(func() {
// After the test, re-initialise with no overrides so that subsequent
// tests start with a clean state.
nbnet.Init()
})
}
// withAdvancedRouting attempts to enable advanced routing for the test.
// On platforms where Init() cannot produce AdvancedRouting()=true (e.g. Linux
// without root), the calling test is skipped.
func withAdvancedRouting(t *testing.T) {
t.Helper()
t.Setenv("NB_USE_LEGACY_ROUTING", "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
nbnet.Init()
t.Cleanup(func() {
nbnet.Init()
})
if !nbnet.AdvancedRouting() {
t.Skip("advanced routing not available in this environment (need root or darwin)")
}
}
// TestIsAddrRouted_AdvancedRoutingShortCircuit verifies that when
// AdvancedRouting() returns true, IsAddrRouted immediately returns
// (false, zero prefix) regardless of any VPN routes, because the WG socket is
// bound directly to the physical interface and bypasses the kernel routing table.
func TestIsAddrRouted_AdvancedRoutingShortCircuit(t *testing.T) {
withAdvancedRouting(t)
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
netip.MustParsePrefix("192.168.0.0/16"),
netip.MustParsePrefix("0.0.0.0/0"),
netip.MustParsePrefix("::/0"),
}
tests := []struct {
name string
addr string
}{
{"IPv4 matched by 10/8", "10.0.0.1"},
{"IPv4 matched by 192.168/16", "192.168.1.1"},
{"IPv4 caught by default", "8.8.8.8"},
{"IPv6 caught by default", "2001:db8::1"},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addr := netip.MustParseAddr(tt.addr)
routed, prefix := IsAddrRouted(addr, vpnRoutes)
assert.False(t, routed, "advanced routing must short-circuit VPN route check")
assert.Equal(t, netip.Prefix{}, prefix, "returned prefix must be zero under advanced routing")
})
}
}
// TestIsAddrRouted_AdvancedRouting_EmptyRoutes verifies the short-circuit with
// an empty vpnRoutes slice — the result must still be (false, zero).
func TestIsAddrRouted_AdvancedRouting_EmptyRoutes(t *testing.T) {
withAdvancedRouting(t)
routed, prefix := IsAddrRouted(netip.MustParseAddr("10.0.0.1"), nil)
assert.False(t, routed)
assert.Equal(t, netip.Prefix{}, prefix)
}
// TestIsAddrRouted_LegacyMode_NonVPNAddress verifies that under legacy routing
// an address that is not covered by any VPN prefix returns false.
func TestIsAddrRouted_LegacyMode_NonVPNAddress(t *testing.T) {
withLegacyRouting(t)
// 198.51.100.x (TEST-NET-2) is not normally in any routing table.
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("10.0.0.0/8"),
}
addr := netip.MustParseAddr("198.51.100.1")
routed, _ := IsAddrRouted(addr, vpnRoutes)
assert.False(t, routed, "address not in VPN routes should not be marked as VPN-routed")
}
// TestIsAddrRouted_LegacyMode_EmptyVPNRoutes verifies that an empty vpnRoutes
// slice always yields (false, zero prefix) even under legacy routing.
func TestIsAddrRouted_LegacyMode_EmptyVPNRoutes(t *testing.T) {
withLegacyRouting(t)
routed, prefix := IsAddrRouted(netip.MustParseAddr("10.0.0.1"), nil)
assert.False(t, routed)
assert.Equal(t, netip.Prefix{}, prefix)
}
// TestIsAddrRouted_AdvancedVsLegacy_ContrastiveBehaviour documents the
// contract difference between the two modes: with a VPN default route and an
// address that matches it, legacy mode may mark it as VPN-routed while advanced
// mode must never do so.
func TestIsAddrRouted_AdvancedVsLegacy_ContrastiveBehaviour(t *testing.T) {
vpnRoutes := []netip.Prefix{
netip.MustParsePrefix("0.0.0.0/0"),
}
addr := netip.MustParseAddr("8.8.8.8")
// --- advanced routing: must always return false ---
t.Run("advanced routing", func(t *testing.T) {
withAdvancedRouting(t)
routed, prefix := IsAddrRouted(addr, vpnRoutes)
assert.False(t, routed, "advanced routing must bypass VPN route lookup")
assert.Equal(t, netip.Prefix{}, prefix)
})
// --- legacy routing: delegates to kernel table check, does not panic ---
t.Run("legacy routing", func(t *testing.T) {
withLegacyRouting(t)
// We don't assert true/false here because it depends on the host
// routing table, but the call must not panic and must return
// a valid (bool, prefix) pair.
routed, prefix := IsAddrRouted(addr, vpnRoutes)
t.Logf("legacy IsAddrRouted(%s, %v) = (%v, %v)", addr, vpnRoutes, routed, prefix)
})
}

View File

@@ -22,6 +22,10 @@ func GetRoutesFromTable() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return []netip.Prefix{}, nil
}
// GetDetailedRoutesFromTable returns empty routes for WASM.
func GetDetailedRoutesFromTable() ([]DetailedRoute, error) {
return []DetailedRoute{}, nil

View File

@@ -894,6 +894,13 @@ func getAddressFamily(prefix netip.Prefix) int {
return netlink.FAMILY_V6
}
func hasSeparateRouting() ([]netip.Prefix, error) {
if !nbnet.AdvancedRouting() {
return GetRoutesFromTable()
}
return nil, ErrRoutingIsSeparate
}
func isOpErr(err error) bool {
// EAFTNOSUPPORT when ipv6 is disabled via sysctl, EOPNOTSUPP when disabled in boot options or otherwise not supported
if errors.Is(err, syscall.EAFNOSUPPORT) || errors.Is(err, syscall.EOPNOTSUPP) {

View File

@@ -48,6 +48,10 @@ func EnableIPForwarding() error {
return nil
}
func hasSeparateRouting() ([]netip.Prefix, error) {
return GetRoutesFromTable()
}
// GetIPRules returns IP rules for debugging (not supported on non-Linux platforms)
func GetIPRules() ([]IPRule, error) {
log.Infof("IP rules collection is not supported on %s", runtime.GOOS)

View File

@@ -25,9 +25,6 @@ import (
const (
envRouteProtoFlag = "NB_ROUTE_PROTO_FLAG"
// routeBudget bounds retries for per-prefix exclusion route programming.
routeBudget = 1 * time.Second
)
var routeProtoFlag int
@@ -44,42 +41,26 @@ func init() {
}
func (r *SysOps) SetupRouting(initAddresses []net.IP, stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.setupAdvancedRouting()
}
log.Infof("Using legacy routing setup with ref counters")
return r.setupRefCounter(initAddresses, stateManager)
}
func (r *SysOps) CleanupRouting(stateManager *statemanager.Manager, advancedRouting bool) error {
if advancedRouting {
return r.cleanupAdvancedRouting()
}
return r.cleanupRefCounter(stateManager)
}
// FlushMarkedRoutes removes single IP exclusion routes marked with the configured RTF_PROTO flag.
// On darwin it also flushes residual RTF_IFSCOPE scoped default routes so a
// crashed prior session can't leave crud in the table.
func (r *SysOps) FlushMarkedRoutes() error {
var merr *multierror.Error
if err := r.flushPlatformExtras(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush platform extras: %w", err))
}
rib, err := retryFetchRIB()
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("fetch routing table: %w", err)))
return fmt.Errorf("fetch routing table: %w", err)
}
msgs, err := route.ParseRIB(route.RIBTypeRoute, rib)
if err != nil {
return nberrors.FormatErrorOrNil(multierror.Append(merr, fmt.Errorf("parse routing table: %w", err)))
return fmt.Errorf("parse routing table: %w", err)
}
var merr *multierror.Error
flushedCount := 0
for _, msg := range msgs {
@@ -136,12 +117,12 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return fmt.Errorf("invalid prefix: %s", prefix)
}
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return fmt.Errorf("build route message: %w", err)
}
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = 1 * time.Second
if err := r.writeRouteMessage(msg, routeBudget); err != nil {
if err := backoff.Retry(r.routeOp(action, prefix, nexthop), expBackOff); err != nil {
a := "add"
if action == unix.RTM_DELETE {
a = "remove"
@@ -151,91 +132,50 @@ func (r *SysOps) routeSocket(action int, prefix netip.Prefix, nexthop Nexthop) e
return nil
}
// writeRouteMessage sends a route message over AF_ROUTE and waits for the
// kernel's matching reply, retrying transient failures until budget elapses.
// Callers do not need to manage sockets or seq numbers themselves.
func (r *SysOps) writeRouteMessage(msg *route.RouteMessage, budget time.Duration) error {
expBackOff := backoff.NewExponentialBackOff()
expBackOff.InitialInterval = 50 * time.Millisecond
expBackOff.MaxInterval = 500 * time.Millisecond
expBackOff.MaxElapsedTime = budget
return backoff.Retry(func() error { return routeMessageRoundtrip(msg) }, expBackOff)
}
func routeMessageRoundtrip(msg *route.RouteMessage) error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("close routing socket: %v", err)
}
}()
tv := unix.Timeval{Sec: 1}
if err := unix.SetsockoptTimeval(fd, unix.SOL_SOCKET, unix.SO_RCVTIMEO, &tv); err != nil {
return backoff.Permanent(fmt.Errorf("set recv timeout: %w", err))
}
// AF_ROUTE is a broadcast channel: every route socket on the host sees
// every RTM_* event. With concurrent route programming the default
// per-socket queue overflows and our own reply gets dropped.
if err := unix.SetsockoptInt(fd, unix.SOL_SOCKET, unix.SO_RCVBUF, 1<<20); err != nil {
log.Debugf("set SO_RCVBUF on route socket: %v", err)
}
bytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal: %w", err))
}
if _, err = unix.Write(fd, bytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
return readRouteResponse(fd, msg.Type, msg.Seq)
}
// readRouteResponse reads from the AF_ROUTE socket until it sees a reply
// matching our write (same type, seq, and pid). AF_ROUTE SOCK_RAW is a
// broadcast channel: interface up/down, third-party route changes and neighbor
// discovery events can all land between our write and read, so we must filter.
func readRouteResponse(fd, wantType, wantSeq int) error {
pid := int32(os.Getpid())
resp := make([]byte, 2048)
deadline := time.Now().Add(time.Second)
for {
if time.Now().After(deadline) {
// Transient: under concurrent pressure the kernel can drop our reply
// from the socket buffer. Let backoff.Retry re-send with a fresh seq.
return fmt.Errorf("read: timeout waiting for route reply type=%d seq=%d", wantType, wantSeq)
}
n, err := unix.Read(fd, resp)
func (r *SysOps) routeOp(action int, prefix netip.Prefix, nexthop Nexthop) func() error {
operation := func() error {
fd, err := unix.Socket(syscall.AF_ROUTE, syscall.SOCK_RAW, syscall.AF_UNSPEC)
if err != nil {
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EWOULDBLOCK) {
// SO_RCVTIMEO fired while waiting; loop to re-check the absolute deadline.
continue
return fmt.Errorf("open routing socket: %w", err)
}
defer func() {
if err := unix.Close(fd); err != nil && !errors.Is(err, unix.EBADF) {
log.Warnf("failed to close routing socket: %v", err)
}
return backoff.Permanent(fmt.Errorf("read: %w", err))
}()
msg, err := r.buildRouteMessage(action, prefix, nexthop)
if err != nil {
return backoff.Permanent(fmt.Errorf("build route message: %w", err))
}
if n < int(unsafe.Sizeof(unix.RtMsghdr{})) {
continue
msgBytes, err := msg.Marshal()
if err != nil {
return backoff.Permanent(fmt.Errorf("marshal route message: %w", err))
}
hdr := (*unix.RtMsghdr)(unsafe.Pointer(&resp[0]))
// Darwin reflects the sender's pid on replies; matching (Type, Seq, Pid)
// uniquely identifies our own reply among broadcast traffic.
if int(hdr.Type) != wantType || int(hdr.Seq) != wantSeq || hdr.Pid != pid {
continue
if _, err = unix.Write(fd, msgBytes); err != nil {
if errors.Is(err, unix.ENOBUFS) || errors.Is(err, unix.EAGAIN) {
return fmt.Errorf("write: %w", err)
}
return backoff.Permanent(fmt.Errorf("write: %w", err))
}
if hdr.Errno != 0 {
return backoff.Permanent(fmt.Errorf("kernel: %w", syscall.Errno(hdr.Errno)))
respBuf := make([]byte, 2048)
n, err := unix.Read(fd, respBuf)
if err != nil {
return backoff.Permanent(fmt.Errorf("read route response: %w", err))
}
if n > 0 {
if err := r.parseRouteResponse(respBuf[:n]); err != nil {
return backoff.Permanent(err)
}
}
return nil
}
return operation
}
func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Nexthop) (msg *route.RouteMessage, err error) {
@@ -243,7 +183,6 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
Type: action,
Flags: unix.RTF_UP | routeProtoFlag,
Version: unix.RTM_VERSION,
ID: uintptr(os.Getpid()),
Seq: r.getSeq(),
}
@@ -282,6 +221,19 @@ func (r *SysOps) buildRouteMessage(action int, prefix netip.Prefix, nexthop Next
return msg, nil
}
func (r *SysOps) parseRouteResponse(buf []byte) error {
if len(buf) < int(unsafe.Sizeof(unix.RtMsghdr{})) {
return nil
}
rtMsg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
if rtMsg.Errno != 0 {
return fmt.Errorf("parse: %d", rtMsg.Errno)
}
return nil
}
// addrToRouteAddr converts a netip.Addr to the appropriate route.Addr (*route.Inet4Addr or *route.Inet6Addr).
func addrToRouteAddr(addr netip.Addr) (route.Addr, error) {
if addr.Is4() {

View File

@@ -2,358 +2,217 @@
package sleep
/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>
extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();
// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
switch (messageType) {
case kIOMessageSystemWillSleep:
sleepCallbackBridge();
IOAllowPowerChange(g_rootPort, (long)messageArgument);
break;
case kIOMessageSystemHasPoweredOn:
poweredOnCallbackBridge();
break;
case kIOMessageServiceIsSuspended:
suspendedCallbackBridge();
break;
case kIOMessageServiceIsResumed:
resumedCallbackBridge();
break;
default:
break;
}
}
static void registerNotifications() {
g_rootPort = IORegisterForSystemPower(
NULL,
&g_notifyPortRef,
(IOServiceInterestCallback)sleepCallback,
&g_notifierObject
);
if (g_rootPort == 0) {
return;
}
CFRunLoopAddSource(CFRunLoopGetCurrent(),
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
g_runLoop = CFRunLoopGetCurrent();
CFRunLoopRun();
}
static void unregisterNotifications() {
CFRunLoopRemoveSource(g_runLoop,
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
IODeregisterForSystemPower(&g_notifierObject);
IOServiceClose(g_rootPort);
IONotificationPortDestroy(g_notifyPortRef);
CFRunLoopStop(g_runLoop);
g_notifyPortRef = NULL;
g_notifierObject = 0;
g_rootPort = 0;
g_runLoop = NULL;
}
*/
import "C"
import (
"context"
"fmt"
"runtime"
"sync"
"time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus"
)
// IOKit message types from IOKit/IOMessage.h.
const (
kIOMessageCanSystemSleep uintptr = 0xe0000270
kIOMessageSystemWillSleep uintptr = 0xe0000280
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
)
var (
ioKit iokitFuncs
cf cfFuncs
cfCommonModes uintptr
libInitOnce sync.Once
libInitErr error
// callbackThunk is the single C-callable trampoline registered with IOKit.
callbackThunk uintptr
serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex
session *runLoopSession
// lifecycleMu serializes Register/Deregister so a new registration can't
// start a second runloop while a previous teardown is still pending.
lifecycleMu sync.Mutex
)
// iokitFuncs holds IOKit symbols resolved once at init.
type iokitFuncs struct {
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
IODeregisterForSystemPower func(notifier *uintptr) int32
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
IOServiceClose func(connect uintptr) int32
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
IONotificationPortDestroy func(port uintptr)
}
// cfFuncs holds CoreFoundation symbols resolved once at init.
type cfFuncs struct {
CFRunLoopGetCurrent func() uintptr
CFRunLoopRun func()
CFRunLoopStop func(rl uintptr)
CFRunLoopAddSource func(rl, source, mode uintptr)
CFRunLoopRemoveSource func(rl, source, mode uintptr)
}
// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
// session means no runloop is active and the next Register must start one.
type runLoopSession struct {
rl uintptr
port uintptr
notifier uintptr
rp uintptr
}
// detectorSnapshot pins a detector's callback and done channel so dispatch
// runs with values valid at snapshot time, even if a concurrent
// Deregister/Register rewrites the detector's fields.
type detectorSnapshot struct {
detector *Detector
callback func(event EventType)
done <-chan struct{}
}
// Detector delivers sleep and wake events to a registered callback.
type Detector struct {
callback func(event EventType)
done chan struct{}
}
// Register installs callback for power events. The first registration starts
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
// registration succeeds or fails; subsequent registrations just add to the
// dispatch set.
func (d *Detector) Register(callback func(event EventType)) error {
lifecycleMu.Lock()
defer lifecycleMu.Unlock()
//export sleepCallbackBridge
func sleepCallbackBridge() {
log.Info("sleepCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeSleep)
}
}
//export resumedCallbackBridge
func resumedCallbackBridge() {
log.Info("resumedCallbackBridge event triggered")
}
//export suspendedCallbackBridge
func suspendedCallbackBridge() {
log.Info("suspendedCallbackBridge event triggered")
}
//export poweredOnCallbackBridge
func poweredOnCallbackBridge() {
log.Info("poweredOnCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeWakeUp)
}
}
type Detector struct {
callback func(event EventType)
ctx context.Context
cancel context.CancelFunc
}
func NewDetector() (*Detector, error) {
return &Detector{}, nil
}
func (d *Detector) Register(callback func(event EventType)) error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
if _, exists := serviceRegistry[d]; exists {
serviceRegistryMu.Unlock()
return fmt.Errorf("detector service already registered")
}
d.callback = callback
d.done = make(chan struct{})
serviceRegistry[d] = struct{}{}
needSetup := session == nil
serviceRegistryMu.Unlock()
if !needSetup {
d.callback = callback
d.ctx, d.cancel = context.WithCancel(context.Background())
if len(serviceRegistry) > 0 {
serviceRegistry[d] = struct{}{}
return nil
}
errCh := make(chan error, 1)
go runRunLoop(errCh)
if err := <-errCh; err != nil {
serviceRegistryMu.Lock()
delete(serviceRegistry, d)
close(d.done)
d.done = nil
serviceRegistryMu.Unlock()
return err
}
serviceRegistry[d] = struct{}{}
// CFRunLoop must run on a single fixed OS thread
go func() {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
C.registerNotifications()
}()
log.Info("sleep detection service started on macOS")
return nil
}
// Deregister removes the detector. When the last detector leaves, IOKit
// notifications are torn down and the runloop is stopped.
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
// and the runloop is stopped and cleaned up.
func (d *Detector) Deregister() error {
lifecycleMu.Lock()
defer lifecycleMu.Unlock()
serviceRegistryMu.Lock()
if _, exists := serviceRegistry[d]; !exists {
serviceRegistryMu.Unlock()
defer serviceRegistryMu.Unlock()
_, exists := serviceRegistry[d]
if !exists {
return nil
}
close(d.done)
// cancel and remove this detector
d.cancel()
delete(serviceRegistry, d)
// If other Detectors still exist, leave IOKit running
if len(serviceRegistry) > 0 {
serviceRegistryMu.Unlock()
return nil
}
sess := session
serviceRegistryMu.Unlock()
log.Info("sleep detection service stopping (deregister)")
if sess == nil {
return nil
}
if sess.rl != 0 && sess.port != 0 {
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
}
if sess.notifier != 0 {
n := sess.notifier
ioKit.IODeregisterForSystemPower(&n)
}
// Clear session only after IODeregisterForSystemPower returns so any
// in-flight powerCallback can still look up session.rp to ack sleep.
serviceRegistryMu.Lock()
session = nil
serviceRegistryMu.Unlock()
if sess.rp != 0 {
ioKit.IOServiceClose(sess.rp)
}
if sess.port != 0 {
ioKit.IONotificationPortDestroy(sess.port)
}
if sess.rl != 0 {
cf.CFRunLoopStop(sess.rl)
}
// Deregister IOKit notifications, stop runloop, and free resources
C.unregisterNotifications()
return nil
}
func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
if cb == nil || done == nil {
return
}
select {
case <-done:
return
default:
}
func (d *Detector) triggerCallback(event EventType) {
doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop()
go func() {
defer close(doneChan)
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep callback: %v", r)
}
}()
cb := d.callback
go func(callback func(event EventType)) {
log.Info("sleep detection event fired")
cb(event)
}()
callback(event)
close(doneChan)
}(cb)
select {
case <-doneChan:
case <-done:
case <-d.ctx.Done():
case <-timeout.C:
log.Warn("sleep callback timed out")
log.Warnf("sleep callback timed out")
}
}
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
func NewDetector() (*Detector, error) {
if err := initLibs(); err != nil {
return nil, err
}
return &Detector{}, nil
}
func initLibs() error {
libInitOnce.Do(func() {
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
return
}
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
return
}
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
if err != nil {
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
return
}
// Launder the uintptr-to-pointer conversion through a Go variable so
// go vet's unsafeptr analyzer doesn't flag a system-library global.
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
// NewCallback slots are a finite, non-reclaimable resource, so register
// a single thunk that dispatches to the current Detector set.
callbackThunk = purego.NewCallback(powerCallback)
})
return libInitErr
}
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
// runloop thread. A Go panic crossing the purego boundary has undefined
// behavior, so contain it here.
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep powerCallback: %v", r)
}
}()
switch messageType {
case kIOMessageCanSystemSleep:
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
allowPowerChange(messageArgument)
case kIOMessageSystemWillSleep:
dispatchEvent(EventTypeSleep)
allowPowerChange(messageArgument)
case kIOMessageSystemHasPoweredOn:
dispatchEvent(EventTypeWakeUp)
}
return 0
}
func allowPowerChange(messageArgument uintptr) {
serviceRegistryMu.Lock()
var port uintptr
if session != nil {
port = session.rp
}
serviceRegistryMu.Unlock()
if port != 0 {
ioKit.IOAllowPowerChange(port, messageArgument)
}
}
func dispatchEvent(event EventType) {
serviceRegistryMu.Lock()
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
for d := range serviceRegistry {
snaps = append(snaps, detectorSnapshot{
detector: d,
callback: d.callback,
done: d.done,
})
}
serviceRegistryMu.Unlock()
for _, s := range snaps {
s.detector.triggerCallback(event, s.callback, s.done)
}
}
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
// result is reported on errCh so Register can surface failures synchronously.
func runRunLoop(errCh chan<- error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
sess, err := setupSession()
if err == nil {
serviceRegistryMu.Lock()
session = sess
serviceRegistryMu.Unlock()
}
errCh <- err
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep runloop: %v", r)
}
}()
cf.CFRunLoopRun()
}
// setupSession performs the IOKit registration on the current thread. Panics
// are converted to errors so runRunLoop never leaves errCh unsent.
func setupSession() (s *runLoopSession, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during runloop setup: %v", r)
}
}()
var portRef, notifier uintptr
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, &notifier)
if rp == 0 {
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
}
rl := cf.CFRunLoopGetCurrent()
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
}

View File

@@ -1,5 +0,0 @@
package net
func (d *Dialer) init() {
d.Dialer.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows
package net

24
client/net/env_android.go Normal file
View File

@@ -0,0 +1,24 @@
//go:build android
package net
// Init initializes the network environment for Android
func Init() {
// No initialization needed on Android
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on Android since we cannot handle routes dynamically.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on Android
func SetVPNInterfaceName(name string) {
// No-op on Android - not needed for Android VPN service
}
// GetVPNInterfaceName returns empty string on Android
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,121 @@
//go:build (darwin && !ios) || windows
package net
import (
"testing"
"github.com/stretchr/testify/assert"
)
// resetAdvancedRoutingState resets the package-level advancedRoutingSupported variable
// and cleans up after the test.
func resetAdvancedRoutingState(t *testing.T) {
t.Helper()
orig := advancedRoutingSupported
t.Cleanup(func() { advancedRoutingSupported = orig })
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingTrue verifies that setting
// NB_USE_LEGACY_ROUTING=true disables advanced routing.
func TestCheckAdvancedRoutingSupport_LegacyRoutingTrue(t *testing.T) {
t.Setenv(envUseLegacyRouting, "true")
assert.False(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingFalse verifies that
// NB_USE_LEGACY_ROUTING=false still allows advanced routing when netstack is off.
func TestCheckAdvancedRoutingSupport_LegacyRoutingFalse(t *testing.T) {
t.Setenv(envUseLegacyRouting, "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
assert.True(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingInvalid verifies that an invalid
// value for NB_USE_LEGACY_ROUTING is ignored (treated as false), so advanced
// routing remains enabled when netstack is off.
func TestCheckAdvancedRoutingSupport_LegacyRoutingInvalid(t *testing.T) {
t.Setenv(envUseLegacyRouting, "notabool")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
// The invalid value is ignored; the default (false) is kept, so advanced routing
// is not suppressed by the legacy-routing flag.
assert.True(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_NetstackEnabled verifies that netstack mode
// disables advanced routing.
func TestCheckAdvancedRoutingSupport_NetstackEnabled(t *testing.T) {
t.Setenv(envUseLegacyRouting, "false")
t.Setenv("NB_USE_NETSTACK_MODE", "true")
assert.False(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_NoEnvVars verifies that with no env overrides
// advanced routing is supported (the happy path).
func TestCheckAdvancedRoutingSupport_NoEnvVars(t *testing.T) {
// Unset both controlling variables so we hit the default path.
t.Setenv(envUseLegacyRouting, "")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
assert.True(t, checkAdvancedRoutingSupport())
}
// TestCheckAdvancedRoutingSupport_LegacyRoutingEmptyString verifies that an
// empty NB_USE_LEGACY_ROUTING is treated as "not set" and does not disable
// advanced routing.
func TestCheckAdvancedRoutingSupport_LegacyRoutingEmptyString(t *testing.T) {
t.Setenv(envUseLegacyRouting, "")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
assert.True(t, checkAdvancedRoutingSupport())
}
// TestAdvancedRouting_ReflectsInit verifies that after calling Init() with
// NB_USE_LEGACY_ROUTING=true, AdvancedRouting() returns false.
func TestAdvancedRouting_ReflectsInit(t *testing.T) {
resetAdvancedRoutingState(t)
t.Setenv(envUseLegacyRouting, "true")
Init()
assert.False(t, AdvancedRouting(), "AdvancedRouting should return false after Init with legacy routing")
}
// TestAdvancedRouting_ReflectsInit_Advanced verifies that after calling Init()
// without legacy overrides, AdvancedRouting() returns true.
func TestAdvancedRouting_ReflectsInit_Advanced(t *testing.T) {
resetAdvancedRoutingState(t)
t.Setenv(envUseLegacyRouting, "false")
t.Setenv("NB_USE_NETSTACK_MODE", "false")
Init()
assert.True(t, AdvancedRouting(), "AdvancedRouting should return true after Init without legacy overrides")
}
// TestSetAndGetVPNInterfaceName verifies SetVPNInterfaceName and GetVPNInterfaceName
// are consistent.
func TestSetAndGetVPNInterfaceName(t *testing.T) {
orig := GetVPNInterfaceName()
t.Cleanup(func() { SetVPNInterfaceName(orig) })
SetVPNInterfaceName("utun3")
assert.Equal(t, "utun3", GetVPNInterfaceName())
}
// TestSetVPNInterfaceName_Empty verifies that setting an empty name is accepted.
func TestSetVPNInterfaceName_Empty(t *testing.T) {
orig := GetVPNInterfaceName()
t.Cleanup(func() { SetVPNInterfaceName(orig) })
SetVPNInterfaceName("")
assert.Equal(t, "", GetVPNInterfaceName())
}
// TestSetVPNInterfaceName_OverwritesPrevious verifies that the second call wins.
func TestSetVPNInterfaceName_OverwritesPrevious(t *testing.T) {
orig := GetVPNInterfaceName()
t.Cleanup(func() { SetVPNInterfaceName(orig) })
SetVPNInterfaceName("utun1")
SetVPNInterfaceName("utun9")
assert.Equal(t, "utun9", GetVPNInterfaceName())
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows && !android
package net

View File

@@ -1,25 +0,0 @@
//go:build ios || android
package net
// Init initializes the network environment for mobile platforms.
func Init() {
// no-op on mobile: routing scope is owned by the VPN extension.
}
// AdvancedRouting reports whether routing loops can be avoided without using exclusion routes.
// Always returns true on mobile since routes cannot be handled dynamically and the VPN extension
// owns the routing scope.
func AdvancedRouting() bool {
return true
}
// SetVPNInterfaceName is a no-op on mobile.
func SetVPNInterfaceName(string) {
// no-op on mobile: the VPN extension manages the interface.
}
// GetVPNInterfaceName returns an empty string on mobile.
func GetVPNInterfaceName() string {
return ""
}

View File

@@ -0,0 +1,39 @@
//go:build ios || android
package net
import (
"testing"
"github.com/stretchr/testify/assert"
)
// TestAdvancedRouting_Mobile verifies that AdvancedRouting always returns true
// on mobile platforms.
func TestAdvancedRouting_Mobile(t *testing.T) {
assert.True(t, AdvancedRouting(), "AdvancedRouting must always be true on mobile")
}
// TestInit_Mobile verifies that Init is a no-op and does not panic.
func TestInit_Mobile(t *testing.T) {
// Should not panic.
Init()
// After Init, AdvancedRouting must still return true.
assert.True(t, AdvancedRouting())
}
// TestSetVPNInterfaceName_Mobile verifies that SetVPNInterfaceName is a no-op
// and does not panic.
func TestSetVPNInterfaceName_Mobile(t *testing.T) {
// Should not panic for any input.
SetVPNInterfaceName("utun0")
SetVPNInterfaceName("")
}
// TestGetVPNInterfaceName_Mobile verifies that GetVPNInterfaceName always
// returns an empty string on mobile.
func TestGetVPNInterfaceName_Mobile(t *testing.T) {
// Even after a SetVPNInterfaceName call (no-op), the getter returns "".
SetVPNInterfaceName("utun0")
assert.Equal(t, "", GetVPNInterfaceName(), "GetVPNInterfaceName must return empty string on mobile")
}

View File

@@ -1,4 +1,4 @@
//go:build (darwin && !ios) || windows
//go:build windows
package net
@@ -24,22 +24,17 @@ func Init() {
}
func checkAdvancedRoutingSupport() bool {
legacyRouting := false
var err error
var legacyRouting bool
if val := os.Getenv(envUseLegacyRouting); val != "" {
parsed, err := strconv.ParseBool(val)
legacyRouting, err = strconv.ParseBool(val)
if err != nil {
log.Warnf("ignoring unparsable %s=%q: %v", envUseLegacyRouting, val, err)
} else {
legacyRouting = parsed
log.Warnf("failed to parse %s: %v", envUseLegacyRouting, err)
}
}
if legacyRouting {
log.Infof("advanced routing disabled: legacy routing requested via %s", envUseLegacyRouting)
return false
}
if netstack.IsEnabled() {
log.Info("advanced routing disabled: netstack mode is enabled")
if legacyRouting || netstack.IsEnabled() {
log.Info("advanced routing has been requested to be disabled")
return false
}

View File

@@ -1,5 +0,0 @@
package net
func (l *ListenerConfig) init() {
l.ListenConfig.Control = applyBoundIfToSocket
}

View File

@@ -1,4 +1,4 @@
//go:build !linux && !windows && !darwin
//go:build !linux && !windows
package net

View File

@@ -1,160 +0,0 @@
package net
import (
"fmt"
"net"
"net/netip"
"strconv"
"strings"
"sync"
"syscall"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// On darwin IPV6_BOUND_IF also scopes v4-mapped egress from dual-stack
// (IPV6_V6ONLY=0) AF_INET6 sockets, so a single setsockopt on "udp6"/"tcp6"
// covers both families. Setting IP_BOUND_IF on an AF_INET6 socket returns
// EINVAL regardless of V6ONLY because the IPPROTO_IP ctloutput path is
// dispatched by socket domain (AF_INET only) not by inp_vflag.
// boundIface holds the physical interface chosen at routing setup time. Sockets
// created via nbnet.NewDialer / nbnet.NewListener bind to it via IP_BOUND_IF
// (IPv4) or IPV6_BOUND_IF (IPv6 / dual-stack) so their scoped route lookup
// hits the RTF_IFSCOPE default installed by the routemanager, rather than
// following the VPN's split default.
var (
boundIfaceMu sync.RWMutex
boundIface4 *net.Interface
boundIface6 *net.Interface
)
// SetBoundInterface records the egress interface for an address family. Called
// by the routemanager after a scoped default route has been installed.
// af must be unix.AF_INET or unix.AF_INET6; other values are ignored.
// nil iface is rejected — use ClearBoundInterfaces to clear all slots.
func SetBoundInterface(af int, iface *net.Interface) {
if iface == nil {
log.Warnf("SetBoundInterface: nil iface for AF %d, ignored", af)
return
}
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
switch af {
case unix.AF_INET:
boundIface4 = iface
case unix.AF_INET6:
boundIface6 = iface
default:
log.Warnf("SetBoundInterface: unsupported address family %d", af)
}
}
// ClearBoundInterfaces resets the cached egress interfaces. Called by the
// routemanager during cleanup.
func ClearBoundInterfaces() {
boundIfaceMu.Lock()
defer boundIfaceMu.Unlock()
boundIface4 = nil
boundIface6 = nil
}
// boundInterfaceFor returns the cached egress interface for a socket's address
// family, falling back to the other family if the preferred slot is empty.
// The kernel stores both IP_BOUND_IF and IPV6_BOUND_IF in inp_boundifp, so
// either setsockopt scopes the socket; preferring same-family still matters
// when v4 and v6 defaults egress different NICs.
func boundInterfaceFor(network, address string) *net.Interface {
if iface := zoneInterface(address); iface != nil {
return iface
}
boundIfaceMu.RLock()
defer boundIfaceMu.RUnlock()
primary, secondary := boundIface4, boundIface6
if isV6Network(network) {
primary, secondary = boundIface6, boundIface4
}
if primary != nil {
return primary
}
return secondary
}
func isV6Network(network string) bool {
return strings.HasSuffix(network, "6")
}
// zoneInterface extracts an explicit interface from an IPv6 link-local zone (e.g. fe80::1%en0).
func zoneInterface(address string) *net.Interface {
if address == "" {
return nil
}
addr, err := netip.ParseAddrPort(address)
if err != nil {
a, err := netip.ParseAddr(address)
if err != nil {
return nil
}
addr = netip.AddrPortFrom(a, 0)
}
zone := addr.Addr().Zone()
if zone == "" {
return nil
}
if iface, err := net.InterfaceByName(zone); err == nil {
return iface
}
if idx, err := strconv.Atoi(zone); err == nil {
if iface, err := net.InterfaceByIndex(idx); err == nil {
return iface
}
}
return nil
}
func setIPv4BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IP, unix.IP_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IP_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
func setIPv6BoundIf(fd uintptr, iface *net.Interface) error {
if err := unix.SetsockoptInt(int(fd), unix.IPPROTO_IPV6, unix.IPV6_BOUND_IF, iface.Index); err != nil {
return fmt.Errorf("set IPV6_BOUND_IF: %w (interface: %s, index: %d)", err, iface.Name, iface.Index)
}
return nil
}
// applyBoundIfToSocket binds the socket to the cached physical egress interface
// so scoped route lookup avoids the VPN utun and egresses the underlay directly.
func applyBoundIfToSocket(network, address string, c syscall.RawConn) error {
if !AdvancedRouting() {
return nil
}
iface := boundInterfaceFor(network, address)
if iface == nil {
log.Debugf("no bound iface cached for %s to %s, skipping BOUND_IF", network, address)
return nil
}
isV6 := isV6Network(network)
var controlErr error
if err := c.Control(func(fd uintptr) {
if isV6 {
controlErr = setIPv6BoundIf(fd, iface)
} else {
controlErr = setIPv4BoundIf(fd, iface)
}
if controlErr == nil {
log.Debugf("set BOUND_IF=%d on %s for %s to %s", iface.Index, iface.Name, network, address)
}
}); err != nil {
return fmt.Errorf("control: %w", err)
}
return controlErr
}

View File

@@ -0,0 +1,286 @@
//go:build darwin && !ios
package net
import (
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/sys/unix"
)
// resetBoundIfaces clears global state before each test that manipulates it.
func resetBoundIfaces(t *testing.T) {
t.Helper()
ClearBoundInterfaces()
t.Cleanup(ClearBoundInterfaces)
}
// TestIsV6Network verifies the isV6Network helper correctly identifies v6 networks.
func TestIsV6Network(t *testing.T) {
tests := []struct {
network string
want bool
}{
{"tcp6", true},
{"udp6", true},
{"ip6", true},
{"tcp", false},
{"udp", false},
{"tcp4", false},
{"udp4", false},
{"ip4", false},
{"ip", false},
{"", false},
// Arbitrary suffix-6 strings
{"unix6", true},
{"custom6", true},
}
for _, tt := range tests {
t.Run(tt.network, func(t *testing.T) {
assert.Equal(t, tt.want, isV6Network(tt.network))
})
}
}
// TestSetBoundInterface_AFINET sets boundIface4 via AF_INET.
func TestSetBoundInterface_AFINET(t *testing.T) {
resetBoundIfaces(t)
iface := &net.Interface{Index: 5, Name: "en0"}
SetBoundInterface(unix.AF_INET, iface)
boundIfaceMu.RLock()
got := boundIface4
boundIfaceMu.RUnlock()
require.NotNil(t, got)
assert.Equal(t, 5, got.Index)
assert.Equal(t, "en0", got.Name)
}
// TestSetBoundInterface_AFINET6 sets boundIface6 via AF_INET6.
func TestSetBoundInterface_AFINET6(t *testing.T) {
resetBoundIfaces(t)
iface := &net.Interface{Index: 7, Name: "utun0"}
SetBoundInterface(unix.AF_INET6, iface)
boundIfaceMu.RLock()
got := boundIface6
boundIfaceMu.RUnlock()
require.NotNil(t, got)
assert.Equal(t, 7, got.Index)
assert.Equal(t, "utun0", got.Name)
}
// TestSetBoundInterface_Nil verifies nil iface is rejected without panicking.
func TestSetBoundInterface_Nil(t *testing.T) {
resetBoundIfaces(t)
// Should not panic, just log a warning and leave existing values untouched.
iface := &net.Interface{Index: 1, Name: "en1"}
SetBoundInterface(unix.AF_INET, iface)
SetBoundInterface(unix.AF_INET, nil)
boundIfaceMu.RLock()
got := boundIface4
boundIfaceMu.RUnlock()
// Value must be unchanged after the rejected nil write.
require.NotNil(t, got)
assert.Equal(t, 1, got.Index)
}
// TestSetBoundInterface_UnknownAF verifies unknown address families are ignored.
func TestSetBoundInterface_UnknownAF(t *testing.T) {
resetBoundIfaces(t)
iface := &net.Interface{Index: 3, Name: "en2"}
// Use an address family that is not AF_INET or AF_INET6.
SetBoundInterface(99, iface)
boundIfaceMu.RLock()
v4, v6 := boundIface4, boundIface6
boundIfaceMu.RUnlock()
assert.Nil(t, v4, "unknown AF must not populate boundIface4")
assert.Nil(t, v6, "unknown AF must not populate boundIface6")
}
// TestClearBoundInterfaces clears both cached interfaces.
func TestClearBoundInterfaces(t *testing.T) {
iface4 := &net.Interface{Index: 1, Name: "en0"}
iface6 := &net.Interface{Index: 2, Name: "en0"}
SetBoundInterface(unix.AF_INET, iface4)
SetBoundInterface(unix.AF_INET6, iface6)
ClearBoundInterfaces()
boundIfaceMu.RLock()
v4, v6 := boundIface4, boundIface6
boundIfaceMu.RUnlock()
assert.Nil(t, v4, "boundIface4 must be nil after clear")
assert.Nil(t, v6, "boundIface6 must be nil after clear")
}
// TestClearBoundInterfaces_Idempotent verifies clearing twice does not panic.
func TestClearBoundInterfaces_Idempotent(t *testing.T) {
ClearBoundInterfaces()
ClearBoundInterfaces()
boundIfaceMu.RLock()
v4, v6 := boundIface4, boundIface6
boundIfaceMu.RUnlock()
assert.Nil(t, v4)
assert.Nil(t, v6)
}
// TestBoundInterfaceFor_PreferSameFamily verifies v4 iface returned for "tcp" and
// v6 iface returned for "tcp6" when both slots are populated.
func TestBoundInterfaceFor_PreferSameFamily(t *testing.T) {
resetBoundIfaces(t)
en0 := &net.Interface{Index: 1, Name: "en0"}
en1 := &net.Interface{Index: 2, Name: "en1"}
SetBoundInterface(unix.AF_INET, en0)
SetBoundInterface(unix.AF_INET6, en1)
got4 := boundInterfaceFor("tcp", "1.2.3.4:80")
require.NotNil(t, got4)
assert.Equal(t, "en0", got4.Name, "tcp should prefer v4 interface")
got6 := boundInterfaceFor("tcp6", "[::1]:80")
require.NotNil(t, got6)
assert.Equal(t, "en1", got6.Name, "tcp6 should prefer v6 interface")
}
// TestBoundInterfaceFor_FallbackToOtherFamily returns the other family's iface
// when the preferred slot is empty.
func TestBoundInterfaceFor_FallbackToOtherFamily(t *testing.T) {
resetBoundIfaces(t)
// Only v4 populated.
en0 := &net.Interface{Index: 1, Name: "en0"}
SetBoundInterface(unix.AF_INET, en0)
// Asking for v6 should fall back to en0.
got := boundInterfaceFor("tcp6", "[::1]:80")
require.NotNil(t, got)
assert.Equal(t, "en0", got.Name)
}
// TestBoundInterfaceFor_BothEmpty returns nil when both slots are empty.
func TestBoundInterfaceFor_BothEmpty(t *testing.T) {
resetBoundIfaces(t)
got := boundInterfaceFor("tcp", "1.2.3.4:80")
assert.Nil(t, got)
got6 := boundInterfaceFor("tcp6", "[::1]:80")
assert.Nil(t, got6)
}
// TestZoneInterface_Empty returns nil for empty address.
func TestZoneInterface_Empty(t *testing.T) {
iface := zoneInterface("")
assert.Nil(t, iface)
}
// TestZoneInterface_NoZone returns nil when address has no zone.
func TestZoneInterface_NoZone(t *testing.T) {
// Regular IPv4 address with port — no zone identifier.
iface := zoneInterface("192.168.1.1:80")
assert.Nil(t, iface)
// Regular IPv6 address with port — no zone identifier.
iface = zoneInterface("[2001:db8::1]:80")
assert.Nil(t, iface)
// Plain IPv6 address without port or zone.
iface = zoneInterface("2001:db8::1")
assert.Nil(t, iface)
}
// TestZoneInterface_InvalidAddress returns nil for completely invalid strings.
func TestZoneInterface_InvalidAddress(t *testing.T) {
iface := zoneInterface("not-an-address")
assert.Nil(t, iface)
iface = zoneInterface("::::")
assert.Nil(t, iface)
}
// TestZoneInterface_NonExistentZoneName returns nil for a zone name that does
// not correspond to a real interface on the host.
func TestZoneInterface_NonExistentZoneName(t *testing.T) {
// Use an interface name that is very unlikely to exist.
iface := zoneInterface("fe80::1%nonexistentiface99999")
assert.Nil(t, iface)
}
// TestZoneInterface_NonExistentZoneIndex returns nil for a zone expressed as an
// integer index that is not in use.
func TestZoneInterface_NonExistentZoneIndex(t *testing.T) {
// Interface index 999999 should not exist on any test machine.
iface := zoneInterface("fe80::1%999999")
assert.Nil(t, iface)
}
// TestBoundInterfaceFor_SetThenClear verifies that clearing state causes
// boundInterfaceFor to return nil afterwards.
func TestBoundInterfaceFor_SetThenClear(t *testing.T) {
resetBoundIfaces(t)
en0 := &net.Interface{Index: 1, Name: "en0"}
SetBoundInterface(unix.AF_INET, en0)
got := boundInterfaceFor("tcp", "1.2.3.4:80")
require.NotNil(t, got, "should return iface while set")
ClearBoundInterfaces()
got = boundInterfaceFor("tcp", "1.2.3.4:80")
assert.Nil(t, got, "should return nil after clear")
}
// TestSetBoundInterface_OverwritesPreviousValue verifies that calling
// SetBoundInterface again updates the stored pointer.
func TestSetBoundInterface_OverwritesPreviousValue(t *testing.T) {
resetBoundIfaces(t)
first := &net.Interface{Index: 1, Name: "en0"}
second := &net.Interface{Index: 3, Name: "en1"}
SetBoundInterface(unix.AF_INET, first)
SetBoundInterface(unix.AF_INET, second)
boundIfaceMu.RLock()
got := boundIface4
boundIfaceMu.RUnlock()
require.NotNil(t, got)
assert.Equal(t, "en1", got.Name, "second call should overwrite first")
}
// TestBoundInterfaceFor_OnlyV6Populated returns v6 iface for v4 network
// when only v6 slot is filled.
func TestBoundInterfaceFor_OnlyV6Populated(t *testing.T) {
resetBoundIfaces(t)
en1 := &net.Interface{Index: 2, Name: "en1"}
SetBoundInterface(unix.AF_INET6, en1)
// v4 network, v4 slot empty → should fall back to v6 slot.
got := boundInterfaceFor("tcp", "1.2.3.4:80")
require.NotNil(t, got)
assert.Equal(t, "en1", got.Name)
}

View File

@@ -13,25 +13,15 @@
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
<!-- Autostart: enabled by default, disable with AUTOSTART=0 on the msiexec command line -->
<Property Id="AUTOSTART" Value="1" />
<StandardDirectory Id="ProgramFiles64Folder">
<Directory Id="NetbirdInstallDir" Name="Netbird">
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
</Shortcut>
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
</Shortcut>
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
</File>
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
<?if $(var.ArchSuffix) = "amd64" ?>
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
<?endif ?>
@@ -56,31 +46,8 @@
</Directory>
</StandardDirectory>
<!-- Per-user component: HKCU keypath (auto GUID via "*"), separate from
the per-machine NetbirdFiles component to satisfy ICE57. -->
<StandardDirectory Id="ProgramMenuFolder">
<Component Id="NetbirdAumidRegistry" Guid="*">
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
</RegistryKey>
</Component>
</StandardDirectory>
<StandardDirectory Id="CommonAppDataFolder">
<Directory Id="NetbirdAutoStartDir" Name="Netbird">
<Component Id="NetbirdAutoStart" Guid="b199eaca-b0dd-4032-af19-679cfad48eb3" Bitness="always64">
<Condition>AUTOSTART = "1"</Condition>
<RegistryValue Root="HKLM" Key="Software\Microsoft\Windows\CurrentVersion\Run"
Name="Netbird" Value="&quot;[NetbirdInstallDir]netbird-ui.exe&quot;"
Type="string" KeyPath="yes" />
</Component>
</Directory>
</StandardDirectory>
<ComponentGroup Id="NetbirdFilesComponent">
<ComponentRef Id="NetbirdFiles" />
<ComponentRef Id="NetbirdAumidRegistry" />
<ComponentRef Id="NetbirdAutoStart" />
</ComponentGroup>
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />

File diff suppressed because it is too large Load Diff

View File

@@ -104,6 +104,8 @@ service DaemonService {
// StopCPUProfile stops CPU profiling in the daemon
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
// ExposeService exposes a local port via the NetBird reverse proxy
@@ -112,6 +114,20 @@ service DaemonService {
message OSLifecycleRequest {
// avoid collision with loglevel enum
enum CycleType {
UNKNOWN = 0;
SLEEP = 1;
WAKEUP = 2;
}
CycleType type = 1;
}
message OSLifecycleResponse {}
message LoginRequest {
// setupKey netbird setup key.
string setupKey = 1;

File diff suppressed because it is too large Load Diff

View File

@@ -120,7 +120,6 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
}
agent := &serverAgent{s}
s.sleepHandler = sleephandler.New(agent)
s.startSleepDetector()
return s
}

View File

@@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -2,18 +2,13 @@ package server
import (
"context"
"os"
"strconv"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/sleep"
"github.com/netbirdio/netbird/client/proto"
)
const envDisableSleepDetector = "NB_DISABLE_SLEEP_DETECTOR"
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
type serverAgent struct {
s *Server
@@ -33,61 +28,19 @@ func (a *serverAgent) Status() (internal.StatusType, error) {
return internal.CtxGetState(a.s.rootCtx).Status()
}
// startSleepDetector starts the OS sleep/wake detector and forwards events to
// the sleep handler. On platforms without a supported detector the attempt
// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips
// registration entirely.
func (s *Server) startSleepDetector() {
if sleepDetectorDisabled() {
log.Info("sleep detection disabled via " + envDisableSleepDetector)
return
}
svc, err := sleep.New()
if err != nil {
log.Warnf("failed to initialize sleep detection: %v", err)
return
}
err = svc.Register(func(event sleep.EventType) {
switch event {
case sleep.EventTypeSleep:
log.Info("handling sleep event")
if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil {
log.Errorf("failed to handle sleep event: %v", err)
}
case sleep.EventTypeWakeUp:
log.Info("handling wakeup event")
if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil {
log.Errorf("failed to handle wakeup event: %v", err)
}
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
switch req.GetType() {
case proto.OSLifecycleRequest_WAKEUP:
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
})
if err != nil {
log.Errorf("failed to register sleep detector: %v", err)
return
}
log.Info("sleep detection service initialized")
go func() {
<-s.rootCtx.Done()
log.Info("stopping sleep event listener")
if err := svc.Deregister(); err != nil {
log.Errorf("failed to deregister sleep detector: %v", err)
case proto.OSLifecycleRequest_SLEEP:
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
return &proto.OSLifecycleResponse{}, err
}
}()
}
func sleepDetectorDisabled() bool {
val := os.Getenv(envDisableSleepDetector)
if val == "" {
return false
default:
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
}
disabled, err := strconv.ParseBool(val)
if err != nil {
log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err)
return false
}
return disabled
return &proto.OSLifecycleResponse{}, nil
}

View File

@@ -12,6 +12,7 @@ import (
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
"github.com/netbirdio/netbird/client/internal/statemanager"
nbnet "github.com/netbirdio/netbird/client/net"
"github.com/netbirdio/netbird/client/proto"
)
@@ -137,8 +138,10 @@ func restoreResidualState(ctx context.Context, statePath string) error {
}
// clean up any remaining routes independently of the state file
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
if !nbnet.AdvancedRouting() {
if err := systemops.New(nil, nil).FlushMarkedRoutes(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("flush marked routes: %w", err))
}
}
return nberrors.FormatErrorOrNil(merr)

View File

@@ -187,23 +187,24 @@ func (m *Manager) buildPeerConfig(allHostPatterns []string) (string, error) {
return "", fmt.Errorf("get NetBird executable path: %w", err)
}
hostList := strings.Join(deduplicatedPatterns, ",")
config := fmt.Sprintf("Match host \"%s\" exec \"%s ssh detect %%h %%p\"\n", hostList, execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
hostLine := strings.Join(deduplicatedPatterns, " ")
config := fmt.Sprintf("Host %s\n", hostLine)
config += fmt.Sprintf(" Match exec \"%s ssh detect %%h %%p\"\n", execPath)
config += " PreferredAuthentications password,publickey,keyboard-interactive\n"
config += " PasswordAuthentication yes\n"
config += " PubkeyAuthentication yes\n"
config += " BatchMode no\n"
config += fmt.Sprintf(" ProxyCommand %s ssh proxy %%h %%p\n", execPath)
config += " StrictHostKeyChecking no\n"
if runtime.GOOS == "windows" {
config += " UserKnownHostsFile NUL\n"
config += " UserKnownHostsFile NUL\n"
} else {
config += " UserKnownHostsFile /dev/null\n"
config += " UserKnownHostsFile /dev/null\n"
}
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
config += " CheckHostIP no\n"
config += " LogLevel ERROR\n\n"
return config, nil
}

View File

@@ -116,37 +116,6 @@ func TestManager_PeerLimit(t *testing.T) {
assert.True(t, os.IsNotExist(err), "SSH config should not be created with too many peers")
}
func TestManager_MatchHostFormat(t *testing.T) {
tempDir, err := os.MkdirTemp("", "netbird-ssh-config-test")
require.NoError(t, err)
defer func() { assert.NoError(t, os.RemoveAll(tempDir)) }()
manager := &Manager{
sshConfigDir: filepath.Join(tempDir, "ssh_config.d"),
sshConfigFile: "99-netbird.conf",
}
peers := []PeerSSHInfo{
{Hostname: "peer1", IP: "100.125.1.1", FQDN: "peer1.nb.internal"},
{Hostname: "peer2", IP: "100.125.1.2", FQDN: "peer2.nb.internal"},
}
err = manager.SetupSSHClientConfig(peers)
require.NoError(t, err)
configPath := filepath.Join(manager.sshConfigDir, manager.sshConfigFile)
content, err := os.ReadFile(configPath)
require.NoError(t, err)
configStr := string(content)
// Must use "Match host" with comma-separated patterns, not a bare "Host" directive.
// A bare "Host" followed by "Match exec" is incorrect per ssh_config(5): the Host block
// ends at the next Match keyword, making it a no-op and leaving the Match exec unscoped.
assert.NotContains(t, configStr, "\nHost ", "should not use bare Host directive")
assert.Contains(t, configStr, "Match host \"100.125.1.1,peer1.nb.internal,peer1,100.125.1.2,peer2.nb.internal,peer2\"",
"should use Match host with comma-separated patterns")
}
func TestManager_ForcedSSHConfig(t *testing.T) {
// Set force environment variable
t.Setenv(EnvForceSSHConfig, "true")

View File

@@ -2,6 +2,7 @@ package system
import (
"context"
"net"
"net/netip"
"strings"
@@ -144,6 +145,59 @@ func extractDeviceName(ctx context.Context, defaultName string) string {
return v
}
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
ipNet, ok := address.(*net.IPNet)
if !ok {
continue
}
if ipNet.IP.IsLoopback() {
continue
}
netAddr := NetworkAddress{
NetIP: netip.MustParsePrefix(ipNet.String()),
Mac: iface.HardwareAddr.String(),
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// GetInfoWithChecks retrieves and parses the system information with applied checks.
func GetInfoWithChecks(ctx context.Context, checks []*proto.Checks) (*Info, error) {
log.Debugf("gathering system information with checks: %d", len(checks))

View File

@@ -2,8 +2,6 @@ package system
import (
"context"
"net"
"net/netip"
"runtime"
log "github.com/sirupsen/logrus"
@@ -44,66 +42,6 @@ func GetInfo(ctx context.Context) *Info {
return gio
}
// networkAddresses returns the list of network addresses on iOS.
// On iOS, hardware (MAC) addresses are not available due to Apple's privacy
// restrictions (iOS returns a fixed 02:00:00:00:00:00 placeholder), so we
// leave Mac empty to match Android's behavior. We also skip the HardwareAddr
// check that other platforms use and filter out link-local addresses as they
// are not useful for posture checks.
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() || ipNet.IP.IsLinkLocalUnicast() || ipNet.IP.IsMulticast() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: ""}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}
// checkFileAndProcess checks if the file path exists and if a process is running at that path.
func checkFileAndProcess(paths []string) ([]File, error) {
return []File{}, nil

View File

@@ -1,66 +0,0 @@
//go:build !ios
package system
import (
"net"
"net/netip"
)
func networkAddresses() ([]NetworkAddress, error) {
interfaces, err := net.Interfaces()
if err != nil {
return nil, err
}
var netAddresses []NetworkAddress
for _, iface := range interfaces {
if iface.Flags&net.FlagUp == 0 {
continue
}
if iface.HardwareAddr.String() == "" {
continue
}
addrs, err := iface.Addrs()
if err != nil {
continue
}
mac := iface.HardwareAddr.String()
for _, address := range addrs {
netAddr, ok := toNetworkAddress(address, mac)
if !ok {
continue
}
if isDuplicated(netAddresses, netAddr) {
continue
}
netAddresses = append(netAddresses, netAddr)
}
}
return netAddresses, nil
}
func toNetworkAddress(address net.Addr, mac string) (NetworkAddress, bool) {
ipNet, ok := address.(*net.IPNet)
if !ok {
return NetworkAddress{}, false
}
if ipNet.IP.IsLoopback() {
return NetworkAddress{}, false
}
prefix, err := netip.ParsePrefix(ipNet.String())
if err != nil {
return NetworkAddress{}, false
}
return NetworkAddress{NetIP: prefix, Mac: mac}, true
}
func isDuplicated(addresses []NetworkAddress, addr NetworkAddress) bool {
for _, duplicated := range addresses {
if duplicated.NetIP == addr.NetIP {
return true
}
}
return false
}

View File

@@ -38,10 +38,10 @@ import (
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/client/internal/sleep"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/client/ui/desktop"
"github.com/netbirdio/netbird/client/ui/event"
"github.com/netbirdio/netbird/client/ui/notifier"
"github.com/netbirdio/netbird/client/ui/process"
"github.com/netbirdio/netbird/util"
@@ -260,7 +260,6 @@ type serviceClient struct {
// application with main windows.
app fyne.App
notifier notifier.Notifier
wSettings fyne.Window
showAdvancedSettings bool
sendNotification bool
@@ -365,7 +364,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
cancel: cancel,
addr: args.addr,
app: args.app,
notifier: notifier.New(args.app),
logFile: args.logFile,
sendNotification: false,
@@ -894,7 +892,7 @@ func (s *serviceClient) updateStatus() error {
if err != nil {
log.Errorf("get service status: %v", err)
if s.connected {
s.notifier.Send("Error", "Connection to service lost")
s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
}
s.setDisconnectedStatus()
return err
@@ -1111,7 +1109,7 @@ func (s *serviceClient) onTrayReady() {
}
}()
s.eventManager = event.NewManager(s.notifier, s.addr)
s.eventManager = event.NewManager(s.app, s.addr)
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
if event.Category == proto.SystemEvent_SYSTEM {
@@ -1148,6 +1146,9 @@ func (s *serviceClient) onTrayReady() {
go s.eventManager.Start(s.ctx)
go s.eventHandler.listen(s.ctx)
// Start sleep detection listener
go s.startSleepListener()
}
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
@@ -1208,6 +1209,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
return s.conn, nil
}
// startSleepListener initializes the sleep detection service and listens for sleep events
func (s *serviceClient) startSleepListener() {
sleepService, err := sleep.New()
if err != nil {
log.Warnf("%v", err)
return
}
if err := sleepService.Register(s.handleSleepEvents); err != nil {
log.Errorf("failed to start sleep detection: %v", err)
return
}
log.Info("sleep detection service initialized")
// Cleanup on context cancellation
go func() {
<-s.ctx.Done()
log.Info("stopping sleep event listener")
if err := sleepService.Deregister(); err != nil {
log.Errorf("failed to deregister sleep detection: %v", err)
}
}()
}
// handleSleepEvents sends a sleep notification to the daemon via gRPC
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
conn, err := s.getSrvClient(0)
if err != nil {
log.Errorf("failed to get daemon client for sleep notification: %v", err)
return
}
req := &proto.OSLifecycleRequest{}
switch event {
case sleep.EventTypeWakeUp:
log.Infof("handle wakeup event: %v", event)
req.Type = proto.OSLifecycleRequest_WAKEUP
case sleep.EventTypeSleep:
log.Infof("handle sleep event: %v", event)
req.Type = proto.OSLifecycleRequest_SLEEP
default:
log.Infof("unknown event: %v", event)
return
}
_, err = conn.NotifyOSLifecycle(s.ctx, req)
if err != nil {
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
return
}
log.Info("successfully notified daemon about os lifecycle")
}
// setSettingsEnabled enables or disables the settings menu based on the provided state
func (s *serviceClient) setSettingsEnabled(enabled bool) {
if s.mSettings != nil {
@@ -1491,7 +1548,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) {
if enforced && s.lastNotifiedVersion != newVersion {
s.lastNotifiedVersion = newVersion
s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install")
s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install"))
}
}

View File

@@ -8,6 +8,7 @@ import (
"sync"
"time"
"fyne.io/fyne/v2"
"github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc"
@@ -17,17 +18,11 @@ import (
"github.com/netbirdio/netbird/client/ui/desktop"
)
// Notifier sends desktop notifications. Defined here so the event package
// does not depend on fyne or the platform-specific notifier implementation.
type Notifier interface {
Send(title, body string)
}
type Handler func(*proto.SystemEvent)
type Manager struct {
notifier Notifier
addr string
app fyne.App
addr string
mu sync.Mutex
ctx context.Context
@@ -36,10 +31,10 @@ type Manager struct {
handlers []Handler
}
func NewManager(notifier Notifier, addr string) *Manager {
func NewManager(app fyne.App, addr string) *Manager {
return &Manager{
notifier: notifier,
addr: addr,
app: app,
addr: addr,
}
}
@@ -119,7 +114,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
if id != "" {
body += fmt.Sprintf(" ID: %s", id)
}
e.notifier.Send(title, body)
e.app.SendNotification(fyne.NewNotification(title, body))
}
for _, handler := range handlers {

View File

@@ -9,6 +9,7 @@ import (
"os"
"os/exec"
"fyne.io/fyne/v2"
"fyne.io/systray"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
@@ -86,7 +87,7 @@ func (h *eventHandler) handleConnectClick() {
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user")
} else {
h.client.notifier.Send("Error", "Failed to connect")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
log.Errorf("connect failed: %v", err)
}
}
@@ -111,7 +112,7 @@ func (h *eventHandler) handleDisconnectClick() {
if err := h.client.menuDownClick(); err != nil {
st, ok := status.FromError(err)
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
h.client.notifier.Send("Error", "Failed to disconnect")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
log.Errorf("disconnect failed: %v", err)
} else {
log.Debugf("disconnect cancelled or already disconnecting")
@@ -129,7 +130,7 @@ func (h *eventHandler) handleAllowSSHClick() {
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update SSH settings")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings"))
}
}
@@ -139,7 +140,7 @@ func (h *eventHandler) handleAutoConnectClick() {
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update auto-connect settings")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings"))
}
}
@@ -148,7 +149,7 @@ func (h *eventHandler) handleRosenpassClick() {
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update Rosenpass settings")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings"))
}
}
@@ -157,7 +158,7 @@ func (h *eventHandler) handleLazyConnectionClick() {
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings"))
}
}
@@ -166,7 +167,7 @@ func (h *eventHandler) handleBlockInboundClick() {
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update block inbound settings")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings"))
}
}
@@ -175,7 +176,7 @@ func (h *eventHandler) handleNotificationsClick() {
if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update notifications settings")
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings"))
} else if h.client.eventManager != nil {
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
}

View File

@@ -1,27 +0,0 @@
// Package notifier sends desktop notifications. On Windows it uses the WinRT
// COM API directly via go-toast/v2 to avoid the PowerShell window flash that
// fyne's default implementation produces. On other platforms it delegates to
// fyne.
package notifier
import "fyne.io/fyne/v2"
// Notifier sends desktop notifications.
type Notifier interface {
Send(title, body string)
}
// New returns a platform-specific Notifier. The fyne app is used as the
// fallback notifier on platforms where no native implementation is wired up,
// and on Windows when the COM path fails to initialize.
func New(app fyne.App) Notifier {
return newNotifier(app)
}
type fyneNotifier struct {
app fyne.App
}
func (f *fyneNotifier) Send(title, body string) {
f.app.SendNotification(fyne.NewNotification(title, body))
}

View File

@@ -1,9 +0,0 @@
//go:build !windows
package notifier
import "fyne.io/fyne/v2"
func newNotifier(app fyne.App) Notifier {
return &fyneNotifier{app: app}
}

View File

@@ -1,88 +0,0 @@
package notifier
import (
"os"
"path/filepath"
"sync"
"fyne.io/fyne/v2"
toast "git.sr.ht/~jackmordaunt/go-toast/v2"
"git.sr.ht/~jackmordaunt/go-toast/v2/wintoast"
log "github.com/sirupsen/logrus"
)
const (
// appID is the AppUserModelID shown in the Windows Action Center. It
// must match the System.AppUserModel.ID property set on the Start Menu
// shortcut by the MSI (see client/netbird.wxs); otherwise Windows
// groups toasts under a separate, unbranded entry.
appID = "NetBird"
// appGUID identifies the COM activation callback class. Generated once
// for NetBird; do not change without coordinating an installer bump,
// since old registry entries pointing at the previous GUID would orphan.
appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}"
)
type comNotifier struct {
fallback *fyneNotifier
ready bool
iconPath string
}
var (
initOnce sync.Once
initErr error
)
func newNotifier(app fyne.App) Notifier {
n := &comNotifier{
fallback: &fyneNotifier{app: app},
iconPath: resolveIcon(),
}
initOnce.Do(func() {
initErr = wintoast.SetAppData(wintoast.AppData{
AppID: appID,
GUID: appGUID,
IconPath: n.iconPath,
})
})
if initErr != nil {
log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr)
return n.fallback
}
n.ready = true
return n
}
func (n *comNotifier) Send(title, body string) {
if !n.ready {
n.fallback.Send(title, body)
return
}
notification := toast.Notification{
AppID: appID,
Title: title,
Body: body,
Icon: n.iconPath,
}
if err := notification.Push(); err != nil {
log.Warnf("toast: push failed, using fyne fallback: %v", err)
n.fallback.Send(title, body)
}
}
// resolveIcon returns an absolute path to the toast icon, or an empty string
// when no icon can be located. Windows requires a PNG/JPG for the
// AppUserModelId IconUri registry value; .ico is silently ignored.
func resolveIcon() string {
exe, err := os.Executable()
if err != nil {
return ""
}
candidate := filepath.Join(filepath.Dir(exe), "netbird.png")
if _, err := os.Stat(candidate); err == nil {
return candidate
}
return ""
}

View File

@@ -548,7 +548,7 @@ func (p *profileMenu) refresh() {
if err != nil {
log.Errorf("failed to switch profile: %v", err)
// show notification dialog
p.serviceClient.notifier.Send("Error", "Failed to switch profile")
p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile"))
return
}
@@ -628,9 +628,9 @@ func (p *profileMenu) refresh() {
}
if err := p.eventHandler.logout(p.ctx); err != nil {
log.Errorf("logout failed: %v", err)
p.serviceClient.notifier.Send("Error", "Failed to deregister")
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
} else {
p.serviceClient.notifier.Send("Success", "Deregistered successfully")
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
}
}
}

View File

@@ -119,8 +119,6 @@ server:
# Reverse proxy settings (optional)
# reverseProxy:
# trustedHTTPProxies: [] # CIDRs of trusted reverse proxies (e.g. ["10.0.0.0/8"])
# trustedHTTPProxiesCount: 0 # Number of trusted proxies in front of the server (alternative to trustedHTTPProxies)
# trustedPeers: [] # CIDRs of trusted peer networks (e.g. ["100.64.0.0/10"])
# accessLogRetentionDays: 7 # Days to retain HTTP access logs. 0 (or unset) defaults to 7. Negative values disable cleanup (logs kept indefinitely).
# accessLogCleanupIntervalHours: 24 # How often (in hours) to run the access-log cleanup job. 0 (or unset) is treated as "not set" and defaults to 24 hours; cleanup remains enabled. To disable cleanup, set accessLogRetentionDays to a negative value.
# trustedHTTPProxies: []
# trustedHTTPProxiesCount: 0
# trustedPeers: []

View File

@@ -457,18 +457,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
client, err := flow.NewClient("http://"+server.addr, "test-payload", "test-signature", 1*time.Second)
require.NoError(t, err)
// Cleanups run LIFO: the goroutine-drain registered here runs after Close below,
// which is when Receive has actually returned. Without this, the Receive goroutine
// can outlive the test and call t.Logf after teardown, panicking.
receiveDone := make(chan struct{})
t.Cleanup(func() {
select {
case <-receiveDone:
case <-time.After(2 * time.Second):
t.Error("Receive goroutine did not exit after Close")
}
})
t.Cleanup(func() {
err := client.Close()
assert.NoError(t, err, "failed to close flow")
@@ -480,7 +468,6 @@ func TestReceive_ProtocolErrorStreamReconnect(t *testing.T) {
receivedAfterReconnect := make(chan struct{})
go func() {
defer close(receiveDone)
err := client.Receive(ctx, 1*time.Second, func(msg *proto.FlowEventAck) error {
if msg.IsInitiator || len(msg.EventId) == 0 {
return nil

5
go.mod
View File

@@ -30,7 +30,6 @@ require (
require (
fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.38.3
github.com/aws/aws-sdk-go-v2/config v1.31.6
@@ -47,7 +46,6 @@ require (
github.com/crowdsecurity/go-cs-bouncer v0.0.21
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex/api/v2 v2.4.0
github.com/ebitengine/purego v0.8.4
github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2
@@ -180,6 +178,7 @@ require (
github.com/docker/docker v28.0.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.1 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect
@@ -324,5 +323,3 @@ replace github.com/pion/ice/v4 => github.com/netbirdio/ice/v4 v4.0.0-20250908184
replace github.com/libp2p/go-netroute => github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944
replace github.com/dexidp/dex => github.com/netbirdio/dex v0.244.0
replace github.com/mailru/easyjson => github.com/netbirdio/easyjson v0.9.0

6
go.sum
View File

@@ -15,8 +15,6 @@ fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 h1:N3IGoHHp9pb6mj1cbXbuaSXV/UMKwmbKLf53nQmtqMA=
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3/go.mod h1:QtOLZGz8olr4qH2vWK0QH0w0O4T9fEIjMuWpKUsH7nc=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
@@ -402,6 +400,8 @@ github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae h1:dIZY4ULFcto4tA
github.com/lufia/plan9stats v0.0.0-20240513124658-fba389f38bae/go.mod h1:ilwx/Dta8jXAgpFYFvSWEMwxmbWXyiUHkd5FwyKhb5k=
github.com/magiconair/properties v1.8.10 h1:s31yESBquKXCV9a/ScB3ESkOjUYYv+X0rg8SYxI99mE=
github.com/magiconair/properties v1.8.10/go.mod h1:Dhd985XPs7jluiymwWYZ0G4Z61jb3vdS329zhj2hYo0=
github.com/mailru/easyjson v0.9.0 h1:PrnmzHw7262yW8sTBwxi1PdJA3Iw/EKBa8psRf7d9a4=
github.com/mailru/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/mattermost/xml-roundtrip-validator v0.1.0 h1:RXbVD2UAl7A7nOTR4u7E3ILa4IbtvKBHw64LDsmu9hU=
github.com/mattermost/xml-roundtrip-validator v0.1.0/go.mod h1:qccnGMcpgwcNaBnxqpJpWWUiPNr5H3O8eDgGV9gT5To=
github.com/mattn/go-isatty v0.0.9/go.mod h1:YNRxwqDuOph6SZLI9vUUz6OYw3QyUt7WiY2yME+cCiQ=
@@ -449,8 +449,6 @@ github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822 h1:C3w9PqII01/Oq
github.com/munnerz/goautoneg v0.0.0-20191010083416-a7dc8b61c822/go.mod h1:+n7T8mK8HuQTcFwEeznm/DIxMOiR9yIdICNftLE1DvQ=
github.com/netbirdio/dex v0.244.0 h1:1GOvi8wnXYassnKGildzNqRHq0RbcfEUw7LKYpKIN7U=
github.com/netbirdio/dex v0.244.0/go.mod h1:STGInJhPcAflrHmDO7vyit2kSq03PdL+8zQPoGALtcU=
github.com/netbirdio/easyjson v0.9.0 h1:6Nw2lghSVuy8RSkAYDhDv1thBVEmfVbKZnV7T7Z6Aus=
github.com/netbirdio/easyjson v0.9.0/go.mod h1:1+xMtQp2MRNVL/V1bOzuP3aP8VNwRW55fQUto+XFtTU=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944 h1:TDtJKmM6Sf8uYFx/dMeqNOL90KUoRscdfpFZ3Im89uk=
github.com/netbirdio/go-netroute v0.0.0-20240611143515-f59b0e1d3944/go.mod h1:sHA6TRxjQ6RLbnI+3R4DZo2Eseg/iKiPRfNmcuNySVQ=
github.com/netbirdio/ice/v4 v4.0.0-20250908184934-6202be846b51 h1:Ov4qdafATOgGMB1wbSuh+0aAHcwz9hdvB6VZjh1mVMI=

View File

@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
// are stored with types that Dex can open.
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
switch connType {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
return "oidc", applyOIDCDefaults(connType, config)
default:
return connType, config
@@ -218,8 +218,6 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
case "okta", "pocketid":
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
}
return augmented

View File

@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
var err error
switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google":
@@ -220,8 +220,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
}
return encodeConnectorConfig(oidcConfig)
}
@@ -285,7 +283,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
// inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
if strings.Contains(connectorIDLower, provider) {
return provider
}

View File

@@ -231,20 +231,7 @@ get_upstream_host() {
wait_management_proxy() {
local proxy_container="${1:-traefik}"
local use_docker_logs=false
set +e
if [[ "$proxy_container" == "detect-traefik" ]]; then
proxy_container=$(docker ps --format "{{.ID}}\t{{.Image}}\t{{.Ports}}" \
| awk -F'\t' '$2 ~ /traefik/ && $3 ~ /:(80|443)->/ {print $1; exit}')
if [[ -z "$proxy_container" ]]; then
echo "Warning: could not auto-detect Traefik container, log output will be skipped on timeout." > /dev/stderr
else
use_docker_logs=true
fi
fi
echo -n "Waiting for NetBird server to become ready"
counter=1
while true; do
@@ -255,13 +242,7 @@ wait_management_proxy() {
if [[ $counter -eq 60 ]]; then
echo ""
echo "Taking too long. Checking logs..."
if [[ -n "$proxy_container" ]]; then
if [[ "$use_docker_logs" == "true" ]]; then
docker logs --tail=20 "$proxy_container"
else
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
fi
fi
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
$DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server
fi
echo -n " ."
@@ -491,7 +472,7 @@ start_services_and_show_instructions() {
if [[ "$ENABLE_CROWDSEC" == "true" ]]; then
echo "Registering CrowdSec bouncer..."
local cs_retries=0
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli lapi status >/dev/null 2>&1; do
while ! $DOCKER_COMPOSE_COMMAND exec -T crowdsec cscli capi status >/dev/null 2>&1; do
cs_retries=$((cs_retries + 1))
if [[ $cs_retries -ge 30 ]]; then
echo "WARNING: CrowdSec did not become ready. Skipping CrowdSec setup." > /dev/stderr
@@ -537,7 +518,7 @@ start_services_and_show_instructions() {
$DOCKER_COMPOSE_COMMAND up -d
sleep 3
wait_management_proxy detect-traefik
wait_management_direct
echo -e "$MSG_DONE"
print_post_setup_instructions

View File

@@ -7,6 +7,7 @@ import (
"os"
"slices"
"strconv"
"strings"
"sync"
"sync/atomic"
"time"
@@ -15,9 +16,11 @@ import (
"golang.org/x/exp/maps"
"golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account"
@@ -55,6 +58,13 @@ type Controller struct {
proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
}
type bufferUpdate struct {
@@ -71,6 +81,29 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
}
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
compactedNetworkMap := true
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
if err != nil && len(compactedEnv) > 0 {
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
}
if err == nil && !parsedCompactedNmap {
log.WithContext(ctx).Info("disabling compacted mode")
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{
repo: newRepository(store),
metrics: nMetrics,
@@ -84,6 +117,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
}
}
@@ -114,9 +153,17 @@ func (c *Controller) CountStreams() int {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
}
globalStart := time.Now()
@@ -150,6 +197,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -192,7 +243,16 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -258,6 +318,10 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
// UpdatePeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
return fmt.Errorf("recalculate network map cache: %v", err)
}
return c.sendUpdateAccountPeers(ctx, accountID)
}
@@ -307,7 +371,16 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err
}
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -378,9 +451,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, emptyMap, nil, 0, nil
}
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
var (
account *types.Account
err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
}
account.InjectProxyPolicies(ctx)
@@ -412,10 +493,20 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, nil, 0, err
}
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(accountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {
@@ -427,6 +518,108 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, networkMap, postureChecks, dnsFwdPort, nil
}
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
c.enrichAccountFromHolder(account)
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil {
@@ -563,7 +756,16 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
}
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
}
@@ -573,6 +775,14 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
}
@@ -607,6 +817,19 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
MessageType: network_map.MessageTypeNetworkMap,
})
c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
}
return c.bufferSendUpdateAccountPeers(ctx, accountID)
@@ -649,11 +872,21 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err
}
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
var networkMap *types.NetworkMap
if c.experimentalNetworkMap(peer.AccountID) {
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
} else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok {

Some files were not shown because too many files have changed in this diff Show More