Compare commits

..

1 Commits

Author SHA1 Message Date
Viktor Liu
afbddae472 Add transparent proxy inspection engine with envoy sidecar support 2026-04-11 18:39:18 +02:00
357 changed files with 19322 additions and 20481 deletions

View File

@@ -1,62 +0,0 @@
name: Proto Version Check
on:
pull_request:
paths:
- "**/*.pb.go"
jobs:
check-proto-versions:
runs-on: ubuntu-latest
steps:
- name: Check for proto tool version changes
uses: actions/github-script@v7
with:
script: |
const files = await github.paginate(github.rest.pulls.listFiles, {
owner: context.repo.owner,
repo: context.repo.repo,
pull_number: context.issue.number,
per_page: 100,
});
const pbFiles = files.filter(f => f.filename.endsWith('.pb.go'));
const missingPatch = pbFiles.filter(f => !f.patch).map(f => f.filename);
if (missingPatch.length > 0) {
core.setFailed(
`Cannot inspect patch data for:\n` +
missingPatch.map(f => `- ${f}`).join('\n') +
`\nThis can happen with very large PRs. Verify proto versions manually.`
);
return;
}
const versionPattern = /^[+-]\s*\/\/\s+protoc(?:-gen-go)?\s+v[\d.]+/;
const violations = [];
for (const file of pbFiles) {
const changed = file.patch
.split('\n')
.filter(line => versionPattern.test(line));
if (changed.length > 0) {
violations.push({
file: file.filename,
lines: changed,
});
}
}
if (violations.length > 0) {
const details = violations.map(v =>
`${v.file}:\n${v.lines.map(l => ' ' + l).join('\n')}`
).join('\n\n');
core.setFailed(
`Proto version strings changed in generated files.\n` +
`This usually means the wrong protoc or protoc-gen-go version was used.\n` +
`Regenerate with the matching tool versions.\n\n` +
details
);
return;
}
console.log('No proto version string changes detected');

View File

@@ -9,7 +9,7 @@ on:
pull_request:
env:
SIGN_PIPE_VER: "v0.1.4"
SIGN_PIPE_VER: "v0.1.1"
GORELEASER_VER: "v2.14.3"
PRODUCT_NAME: "NetBird"
COPYRIGHT: "NetBird GmbH"
@@ -114,13 +114,7 @@ jobs:
retention-days: 30
release:
runs-on: ubuntu-24.04-8-core
outputs:
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
runs-on: ubuntu-latest-m
env:
flags: ""
steps:
@@ -219,13 +213,10 @@ jobs:
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: Tag and push images (amd64 only)
id: tag_and_push_images
if: |
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
run: |
set -euo pipefail
resolve_tags() {
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
echo "pr-${{ github.event.pull_request.number }}"
@@ -234,17 +225,6 @@ jobs:
fi
}
ghcr_package_url() {
local image="$1" package encoded_package
package="${image#ghcr.io/}"
package="${package#*/}"
package="${package%%:*}"
encoded_package="${package//\//%2F}"
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
}
image_refs=()
tag_and_push() {
local src="$1" img_name tag dst
img_name="${src%%:*}"
@@ -253,56 +233,35 @@ jobs:
echo "Tagging ${src} -> ${dst}"
docker tag "$src" "$dst"
docker push "$dst"
image_refs+=("$dst")
done
}
cat > /tmp/goreleaser-artifacts.json <<'JSON'
${{ steps.goreleaser.outputs.artifacts }}
JSON
export -f tag_and_push resolve_tags
mapfile -t src_images < <(
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
)
for src in "${src_images[@]}"; do
tag_and_push "$src"
done
{
echo "images_markdown<<EOF"
if [[ ${#image_refs[@]} -eq 0 ]]; then
echo "_No GHCR images were pushed._"
else
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
done
fi
echo "EOF"
} >> "$GITHUB_OUTPUT"
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
grep '^ghcr.io/' | while read -r SRC; do
tag_and_push "$SRC"
done
- name: upload non tags for debug purposes
id: upload_release
uses: actions/upload-artifact@v4
with:
name: release
path: dist/
retention-days: 7
- name: upload linux packages
id: upload_linux_packages
uses: actions/upload-artifact@v4
with:
name: linux-packages
path: dist/netbird_linux**
retention-days: 7
- name: upload windows packages
id: upload_windows_packages
uses: actions/upload-artifact@v4
with:
name: windows-packages
path: dist/netbird_windows**
retention-days: 7
- name: upload macos packages
id: upload_macos_packages
uses: actions/upload-artifact@v4
with:
name: macos-packages
@@ -311,8 +270,6 @@ jobs:
release_ui:
runs-on: ubuntu-latest
outputs:
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
steps:
- name: Parse semver string
id: semver_parser
@@ -403,7 +360,6 @@ jobs:
if: always()
run: rm -f /tmp/gpg-rpm-signing-key.asc
- name: upload non tags for debug purposes
id: upload_release_ui
uses: actions/upload-artifact@v4
with:
name: release-ui
@@ -412,8 +368,6 @@ jobs:
release_ui_darwin:
runs-on: macos-latest
outputs:
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
steps:
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
run: echo "flags=--snapshot" >> $GITHUB_ENV
@@ -448,258 +402,15 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
- name: upload non tags for debug purposes
id: upload_release_ui_darwin
uses: actions/upload-artifact@v4
with:
name: release-ui-darwin
path: dist/
retention-days: 3
test_windows_installer:
name: "Windows Installer / Build Test"
runs-on: windows-2022
needs: [release, release_ui]
strategy:
fail-fast: false
matrix:
include:
- arch: amd64
wintun_arch: amd64
- arch: arm64
wintun_arch: arm64
defaults:
run:
shell: powershell
env:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@v4
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@v4
with:
name: release-ui
path: release-ui
- name: Stage binaries into dist
run: |
$workdir = "dist\${{ env.PackageWorkdir }}"
New-Item -ItemType Directory -Force -Path $workdir | Out-Null
$client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
$ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 }
if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 }
Write-Host "Client: $($client.FullName)"
Write-Host "UI: $($ui.FullName)"
tar -zvxf $client.FullName -C $workdir
tar -zvxf $ui.FullName -C $workdir
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompress wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
with:
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
- name: Move opengl32.dll into dist (amd64 only)
if: matrix.arch == 'amd64'
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: carlosperate/download-file-action@v2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
- name: Install WiX
run: |
dotnet tool install --global wix --version 6.0.2
wix extension add WixToolset.Util.wixext/6.0.2
- name: Build MSI installer
env:
NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}"
run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }}
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
netbird_installer_test_windows_${{ matrix.arch }}.exe
netbird_installer_test_windows_${{ matrix.arch }}.msi
retention-days: 3
comment_release_artifacts:
name: Comment release artifacts
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin]
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const marker = '<!-- netbird-release-artifacts -->';
const { owner, repo } = context.repo;
const issue_number = context.payload.pull_request.number;
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
const artifactCell = (url, result) => {
if (url) return `[Download](${url})`;
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
};
const artifacts = [
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
];
const artifactRows = artifacts
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
.join('\n');
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
const body = [
marker,
'## Release artifacts',
'',
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
'',
'| Artifact | Link |',
'| --- | --- |',
artifactRows,
'',
'### GHCR images (amd64)',
ghcrImages,
'',
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
].join('\n');
const comments = await github.paginate(github.rest.issues.listComments, {
owner,
repo,
issue_number,
per_page: 100,
});
const previous = comments.find(comment =>
comment.user?.type === 'Bot' && comment.body?.includes(marker)
);
if (previous) {
await github.rest.issues.updateComment({
owner,
repo,
comment_id: previous.id,
body,
});
core.info(`Updated release artifacts comment ${previous.id}`);
} else {
const { data } = await github.rest.issues.createComment({
owner,
repo,
issue_number,
body,
});
core.info(`Created release artifacts comment ${data.id}`);
}
trigger_signer:
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin, test_windows_installer]
needs: [release, release_ui, release_ui_darwin]
if: startsWith(github.ref, 'refs/tags/')
steps:
- name: Trigger binaries sign pipelines

View File

@@ -9,8 +9,6 @@ concurrency:
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
cancel-in-progress: true
# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
jobs:
trigger_sync_tag:
runs-on: ubuntu-latest
@@ -23,29 +21,3 @@ jobs:
repo: ${{ secrets.UPSTREAM_REPO }}
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_android_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger android-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/android-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'
trigger_ios_bump:
runs-on: ubuntu-latest
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
steps:
- name: Trigger ios-client submodule bump
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
with:
workflow: bump-netbird.yml
ref: main
repo: netbirdio/ios-client
token: ${{ secrets.NC_GITHUB_TOKEN }}
inputs: '{ "tag": "${{ github.ref_name }}" }'

2
.gitignore vendored
View File

@@ -33,3 +33,5 @@ infrastructure_files/setup-*.env
vendor/
/netbird
client/netbird-electron/
management/server/types/testdata/comparison/
management/server/types/testdata/*.json

View File

@@ -58,11 +58,6 @@ linters:
govet:
enable:
- nilness
disable:
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
# directives but cannot perform the rewrite due to generic type
# parameter inference limitations in the Go inliner.
- inline
enable-all: false
revive:
rules:

View File

@@ -5,7 +5,7 @@ GOLANGCI_LINT := $(shell pwd)/bin/golangci-lint
$(GOLANGCI_LINT):
@echo "Installing golangci-lint..."
@mkdir -p ./bin
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/v2/cmd/golangci-lint@latest
@GOBIN=$(shell pwd)/bin go install github.com/golangci/golangci-lint/cmd/golangci-lint@latest
# Lint only changed files (fast, for pre-push)
lint: $(GOLANGCI_LINT)

View File

@@ -17,7 +17,6 @@ ENV \
NETBIRD_BIN="/usr/local/bin/netbird" \
NB_LOG_FILE="console,/var/log/netbird/client.log" \
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
NB_ENABLE_CAPTURE="false" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,7 +23,6 @@ ENV \
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
NB_DISABLE_DNS="true" \
NB_ENABLE_CAPTURE="false" \
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -8,7 +8,6 @@ import (
"os"
"slices"
"sync"
"time"
"golang.org/x/exp/maps"
@@ -16,7 +15,6 @@ import (
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/listener"
"github.com/netbirdio/netbird/client/internal/peer"
@@ -28,7 +26,6 @@ import (
"github.com/netbirdio/netbird/formatter"
"github.com/netbirdio/netbird/route"
"github.com/netbirdio/netbird/shared/management/domain"
types "github.com/netbirdio/netbird/upload-server/types"
)
// ConnectionListener export internal Listener for mobile
@@ -71,30 +68,7 @@ type Client struct {
uiVersion string
networkChangeListener listener.NetworkChangeListener
stateMu sync.RWMutex
connectClient *internal.ConnectClient
config *profilemanager.Config
cacheDir string
}
func (c *Client) setState(cfg *profilemanager.Config, cacheDir string, cc *internal.ConnectClient) {
c.stateMu.Lock()
defer c.stateMu.Unlock()
c.config = cfg
c.cacheDir = cacheDir
c.connectClient = cc
}
func (c *Client) stateSnapshot() (*profilemanager.Config, string, *internal.ConnectClient) {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.config, c.cacheDir, c.connectClient
}
func (c *Client) getConnectClient() *internal.ConnectClient {
c.stateMu.RLock()
defer c.stateMu.RUnlock()
return c.connectClient
}
// NewClient instantiate a new Client
@@ -119,7 +93,6 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client with config: %s, state: %s", cfgFile, stateFile)
@@ -151,9 +124,8 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
@@ -163,7 +135,6 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
cfgFile := platformFiles.ConfigurationFilePath()
stateFile := platformFiles.StateFilePath()
cacheDir := platformFiles.CacheDir()
log.Infof("Starting client without login with config: %s, state: %s", cfgFile, stateFile)
@@ -186,9 +157,8 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
// todo do not throw error in case of cancelled context
ctx = internal.CtxInitState(ctx)
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
c.setState(cfg, cacheDir, connectClient)
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile, cacheDir)
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
}
// Stop the internal client and free the resources
@@ -203,12 +173,11 @@ func (c *Client) Stop() {
}
func (c *Client) RenewTun(fd int) error {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
return fmt.Errorf("engine not running")
}
e := cc.Engine()
e := c.connectClient.Engine()
if e == nil {
return fmt.Errorf("engine not initialized")
}
@@ -216,73 +185,6 @@ func (c *Client) RenewTun(fd int) error {
return e.RenewTun(fd)
}
// DebugBundle generates a debug bundle, uploads it, and returns the upload key.
// It works both with and without a running engine.
func (c *Client) DebugBundle(platformFiles PlatformFiles, anonymize bool) (string, error) {
cfg, cacheDir, cc := c.stateSnapshot()
// If the engine hasn't been started, load config from disk
if cfg == nil {
var err error
cfg, err = profilemanager.UpdateOrCreateConfig(profilemanager.ConfigInput{
ConfigPath: platformFiles.ConfigurationFilePath(),
})
if err != nil {
return "", fmt.Errorf("load config: %w", err)
}
cacheDir = platformFiles.CacheDir()
}
deps := debug.GeneratorDependencies{
InternalConfig: cfg,
StatusRecorder: c.recorder,
TempDir: cacheDir,
}
if cc != nil {
resp, err := cc.GetLatestSyncResponse()
if err != nil {
log.Warnf("get latest sync response: %v", err)
}
deps.SyncResponse = resp
if e := cc.Engine(); e != nil {
if cm := e.GetClientMetrics(); cm != nil {
deps.ClientMetrics = cm
}
}
}
bundleGenerator := debug.NewBundleGenerator(
deps,
debug.BundleConfig{
Anonymize: anonymize,
IncludeSystemInfo: true,
},
)
path, err := bundleGenerator.Generate()
if err != nil {
return "", fmt.Errorf("generate debug bundle: %w", err)
}
defer func() {
if err := os.Remove(path); err != nil {
log.Errorf("failed to remove debug bundle file: %v", err)
}
}()
uploadCtx, cancel := context.WithTimeout(context.Background(), 2*time.Minute)
defer cancel()
key, err := debug.UploadDebugBundle(uploadCtx, types.DefaultBundleURL, cfg.ManagementURL.String(), path)
if err != nil {
return "", fmt.Errorf("upload debug bundle: %w", err)
}
log.Infof("debug bundle uploaded with key %s", key)
return key, nil
}
// SetTraceLogLevel configure the logger to trace level
func (c *Client) SetTraceLogLevel() {
log.SetLevel(log.TraceLevel)
@@ -312,13 +214,12 @@ func (c *Client) PeersList() *PeerInfoArray {
}
func (c *Client) Networks() *NetworkArray {
cc := c.getConnectClient()
if cc == nil {
if c.connectClient == nil {
log.Error("not connected")
return nil
}
engine := cc.Engine()
engine := c.connectClient.Engine()
if engine == nil {
log.Error("could not get engine")
return nil
@@ -399,7 +300,7 @@ func (c *Client) toggleRoute(command routeCommand) error {
}
func (c *Client) getRouteManager() (routemanager.Manager, error) {
client := c.getConnectClient()
client := c.connectClient
if client == nil {
return nil, fmt.Errorf("not connected")
}

View File

@@ -7,5 +7,4 @@ package android
type PlatformFiles interface {
ConfigurationFilePath() string
StateFilePath() string
CacheDir() string
}

View File

@@ -1,196 +0,0 @@
package cmd
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"github.com/hashicorp/go-multierror"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util/capture"
)
var captureCmd = &cobra.Command{
Use: "capture",
Short: "Capture packets on the WireGuard interface",
Long: `Captures decrypted packets flowing through the WireGuard interface.
Default output is human-readable text. Use --pcap or --output for pcap binary.
Requires --enable-capture to be set at service install or reconfigure time.
Examples:
netbird debug capture
netbird debug capture host 100.64.0.1 and port 443
netbird debug capture tcp
netbird debug capture icmp
netbird debug capture src host 10.0.0.1 and dst port 80
netbird debug capture -o capture.pcap
netbird debug capture --pcap | tshark -r -
netbird debug capture --pcap | tcpdump -r - -n`,
Args: cobra.ArbitraryArgs,
RunE: runCapture,
}
func init() {
debugCmd.AddCommand(captureCmd)
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
}
func runCapture(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
cmd.PrintErrf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
req, err := buildCaptureRequest(cmd, args)
if err != nil {
return err
}
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
stream, err := client.StartCapture(ctx, req)
if err != nil {
return handleCaptureError(err)
}
// First Recv is the empty acceptance message from the server. If the
// device is unavailable (kernel WG, not connected, capture disabled),
// the server returns an error instead.
if _, err := stream.Recv(); err != nil {
return handleCaptureError(err)
}
out, cleanup, err := captureOutput(cmd)
if err != nil {
return err
}
if req.TextOutput {
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
} else {
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
}
streamErr := streamCapture(ctx, cmd, stream, out)
cleanupErr := cleanup()
if streamErr != nil {
return streamErr
}
return cleanupErr
}
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
req := &proto.StartCaptureRequest{}
if len(args) > 0 {
expr := strings.Join(args, " ")
if _, err := capture.ParseFilter(expr); err != nil {
return nil, fmt.Errorf("invalid filter: %w", err)
}
req.FilterExpr = expr
}
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
req.SnapLen = snap
}
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
if d < 0 {
return nil, fmt.Errorf("duration must not be negative")
}
req.Duration = durationpb.New(d)
}
req.Verbose, _ = cmd.Flags().GetBool("verbose")
req.Ascii, _ = cmd.Flags().GetBool("ascii")
outPath, _ := cmd.Flags().GetString("output")
forcePcap, _ := cmd.Flags().GetBool("pcap")
req.TextOutput = !forcePcap && outPath == ""
return req, nil
}
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
for {
pkt, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.PrintErrf("\nCapture stopped.\n")
return nil //nolint:nilerr // user interrupted
}
if err == io.EOF {
cmd.PrintErrf("\nCapture finished.\n")
return nil
}
return handleCaptureError(err)
}
if _, err := out.Write(pkt.GetData()); err != nil {
return fmt.Errorf("write output: %w", err)
}
}
}
// captureOutput returns the writer for capture data and a cleanup function
// that finalizes the file. Errors from the cleanup must be propagated.
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
outPath, _ := cmd.Flags().GetString("output")
if outPath == "" {
return os.Stdout, func() error { return nil }, nil
}
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
if err != nil {
return nil, nil, fmt.Errorf("create output file: %w", err)
}
tmpPath := f.Name()
return f, func() error {
var merr *multierror.Error
if err := f.Close(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
}
fi, statErr := os.Stat(tmpPath)
if statErr != nil || fi.Size() == 0 {
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
}
return nberrors.FormatErrorOrNil(merr)
}
if err := os.Rename(tmpPath, outPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
return nberrors.FormatErrorOrNil(merr)
}
cmd.PrintErrf("Wrote %s\n", outPath)
return nberrors.FormatErrorOrNil(merr)
}, nil
}
func handleCaptureError(err error) error {
if s, ok := status.FromError(err); ok {
return fmt.Errorf("%s", s.Message())
}
return err
}

View File

@@ -9,7 +9,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug"
@@ -240,50 +239,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
}()
}
captureStarted := false
if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
captureTimeout := duration + 30*time.Second
const maxBundleCapture = 10 * time.Minute
if captureTimeout > maxBundleCapture {
captureTimeout = maxBundleCapture
}
_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
Timeout: durationpb.New(captureTimeout),
})
if err != nil {
cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
} else {
captureStarted = true
cmd.Println("Packet capture started.")
// Safety: always stop on exit, even if the normal stop below runs too.
defer func() {
if captureStarted {
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
}
}
}()
}
}
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
return waitErr
}
cmd.Println("\nDuration completed")
if captureStarted {
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
} else {
captureStarted = false
cmd.Println("Packet capture stopped.")
}
}
if cpuProfilingStarted {
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
@@ -456,5 +416,4 @@ func init() {
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
}

View File

@@ -10,7 +10,6 @@ import (
log "github.com/sirupsen/logrus"
"github.com/spf13/cobra"
"golang.org/x/term"
"google.golang.org/grpc/codes"
gstatus "google.golang.org/grpc/status"
@@ -24,7 +23,6 @@ import (
func init() {
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
}
@@ -258,7 +256,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
}
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
if err != nil {
@@ -326,7 +324,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
}
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
if err != nil {
@@ -336,7 +334,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
return &tokenInfo, nil
}
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
var codeMsg string
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -350,12 +348,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
verificationURIComplete + " " + codeMsg)
}
if showQR {
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
printQRCode(f, verificationURIComplete)
}
}
cmd.Println("")
if !noBrowser {

View File

@@ -1,25 +0,0 @@
package cmd
import (
"io"
"github.com/mdp/qrterminal/v3"
)
// printQRCode prints a QR code for the given URL to the writer.
// Called only when the user explicitly requests QR output via --qr.
func printQRCode(w io.Writer, url string) {
if url == "" {
return
}
qrterminal.GenerateWithConfig(url, qrterminal.Config{
Level: qrterminal.M,
Writer: w,
HalfBlocks: true,
BlackChar: qrterminal.BLACK_BLACK,
WhiteChar: qrterminal.WHITE_WHITE,
BlackWhiteChar: qrterminal.BLACK_WHITE,
WhiteBlackChar: qrterminal.WHITE_BLACK,
QuietZone: qrterminal.QUIET_ZONE,
})
}

View File

@@ -1,26 +0,0 @@
package cmd
import (
"bytes"
"testing"
)
func TestPrintQRCode_EmptyURL(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "")
if buf.Len() != 0 {
t.Error("expected no output for empty URL")
}
}
func TestPrintQRCode_WritesOutput(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "https://example.com/auth")
if buf.Len() == 0 {
t.Error("expected QR code output for non-empty URL")
}
}

View File

@@ -75,8 +75,6 @@ var (
mtu uint16
profilesDisabled bool
updateSettingsDisabled bool
captureEnabled bool
networksDisabled bool
rootCmd = &cobra.Command{
Use: "netbird",

View File

@@ -44,14 +44,10 @@ func init() {
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
serviceEnvDesc := `Sets extra environment variables for the service. ` +
`You can specify a comma-separated list of KEY=VALUE pairs. ` +
`New keys are merged with previously saved env vars; existing keys are overwritten. ` +
`Use --service-env "" to clear all saved env vars. ` +
`E.g. --service-env NB_LOG_LEVEL=debug,CUSTOM_VAR=value`
installCmd.Flags().StringSliceVar(&serviceEnvVars, "service-env", nil, serviceEnvDesc)

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
}
}
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled)
if err := serverInstance.Start(); err != nil {
log.Fatalf("failed to start daemon: %v", err)
}

View File

@@ -59,14 +59,6 @@ func buildServiceArguments() []string {
args = append(args, "--disable-update-settings")
}
if captureEnabled {
args = append(args, "--enable-capture")
}
if networksDisabled {
args = append(args, "--disable-networks")
}
return args
}

View File

@@ -28,8 +28,6 @@ type serviceParams struct {
LogFiles []string `json:"log_files,omitempty"`
DisableProfiles bool `json:"disable_profiles,omitempty"`
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
EnableCapture bool `json:"enable_capture,omitempty"`
DisableNetworks bool `json:"disable_networks,omitempty"`
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
}
@@ -80,13 +78,11 @@ func currentServiceParams() *serviceParams {
LogFiles: logFiles,
DisableProfiles: profilesDisabled,
DisableUpdateSettings: updateSettingsDisabled,
EnableCapture: captureEnabled,
DisableNetworks: networksDisabled,
}
if len(serviceEnvVars) > 0 {
parsed, err := parseServiceEnvVars(serviceEnvVars)
if err == nil {
if err == nil && len(parsed) > 0 {
params.ServiceEnvVars = parsed
}
}
@@ -146,50 +142,31 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
updateSettingsDisabled = params.DisableUpdateSettings
}
if !serviceCmd.PersistentFlags().Changed("enable-capture") {
captureEnabled = params.EnableCapture
}
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
networksDisabled = params.DisableNetworks
}
applyServiceEnvParams(cmd, params)
}
// applyServiceEnvParams merges saved service environment variables.
// If --service-env was explicitly set with values, explicit values win on key
// conflict but saved keys not in the explicit set are carried over.
// If --service-env was explicitly set to empty, all saved env vars are cleared.
// If --service-env was explicitly set, explicit values win on key conflict
// but saved keys not in the explicit set are carried over.
// If --service-env was not set, saved env vars are used entirely.
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
if !cmd.Flags().Changed("service-env") {
if len(params.ServiceEnvVars) > 0 {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
}
if len(params.ServiceEnvVars) == 0 {
return
}
// Flag was explicitly set: parse what the user provided.
if !cmd.Flags().Changed("service-env") {
// No explicit env vars: rebuild serviceEnvVars from saved params.
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
return
}
// Explicit env vars were provided: merge saved values underneath.
explicit, err := parseServiceEnvVars(serviceEnvVars)
if err != nil {
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
return
}
// If the user passed an empty value (e.g. --service-env ""), clear all
// saved env vars rather than merging.
if len(explicit) == 0 {
serviceEnvVars = nil
return
}
if len(params.ServiceEnvVars) == 0 {
return
}
// Merge saved values underneath explicit ones.
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
maps.Copy(merged, params.ServiceEnvVars)
maps.Copy(merged, explicit) // explicit wins on conflict

View File

@@ -327,41 +327,6 @@ func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
}
func TestApplyServiceEnvParams_ExplicitEmptyClears(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
cmd := &cobra.Command{}
cmd.Flags().StringSlice("service-env", nil, "")
require.NoError(t, cmd.Flags().Set("service-env", ""))
saved := &serviceParams{
ServiceEnvVars: map[string]string{"OLD_VAR": "should_be_cleared"},
}
applyServiceEnvParams(cmd, saved)
assert.Nil(t, serviceEnvVars, "explicit empty --service-env should clear all saved env vars")
}
func TestCurrentServiceParams_EmptyEnvVarsAfterParse(t *testing.T) {
origServiceEnvVars := serviceEnvVars
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
// Simulate --service-env "" which produces [""] in the slice.
serviceEnvVars = []string{""}
params := currentServiceParams()
// After parsing, the empty string is skipped, resulting in an empty map.
// The map should still be set (not nil) so it overwrites saved values.
assert.NotNil(t, params.ServiceEnvVars, "empty env vars should produce empty map, not nil")
assert.Empty(t, params.ServiceEnvVars, "no valid env vars should be parsed from empty string")
}
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
// added to serviceParams but not wired into these functions, this test fails.
@@ -535,8 +500,6 @@ func fieldToGlobalVar(field string) string {
"LogFiles": "logFiles",
"DisableProfiles": "profilesDisabled",
"DisableUpdateSettings": "updateSettingsDisabled",
"EnableCapture": "captureEnabled",
"DisableNetworks": "networksDisabled",
"ServiceEnvVars": "serviceEnvVars",
}
if v, ok := m[field]; ok {

View File

@@ -13,8 +13,6 @@ import (
"github.com/netbirdio/management-integrations/integrations"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/update_channel"
"github.com/netbirdio/netbird/management/internals/modules/peers"
@@ -102,16 +100,9 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
jobManager := job.NewJobManager(nil, store, peersmanager)
ctx := context.Background()
iv, _ := integrations.NewIntegratedValidator(context.Background(), peersmanager, settingsManagerMock, eventStore)
cacheStore, err := nbcache.NewStore(ctx, 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
t.Fatal(err)
}
iv, _ := integrations.NewIntegratedValidator(ctx, peersmanager, settingsManagerMock, eventStore, cacheStore)
metrics, err := telemetry.NewDefaultAppMetrics(ctx)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
settingsMockManager := settings.NewMockManager(ctrl)
@@ -122,11 +113,12 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
Return(&types.Settings{}, nil).
AnyTimes()
ctx := context.Background()
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := mgmt.NewAccountRequestBuffer(ctx, store)
networkMapController := controller.NewController(ctx, store, metrics, updateManager, requestBuffer, mgmt.MockIntegratedValidator{}, settingsMockManager, "netbird.cloud", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersmanager), config)
accountManager, err := mgmt.BuildManager(ctx, config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false, cacheStore)
accountManager, err := mgmt.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
if err != nil {
t.Fatal(err)
}
@@ -135,7 +127,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
if err != nil {
t.Fatal(err)
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
t.Fatal(err)
}
@@ -160,7 +152,7 @@ func startClientDaemon(
s := grpc.NewServer()
server := client.New(ctx,
"", "", false, false, false, false)
"", "", false, false)
if err := server.Start(); err != nil {
t.Fatal(err)
}

View File

@@ -39,9 +39,6 @@ const (
noBrowserFlag = "no-browser"
noBrowserDesc = "do not open the browser for SSO login"
showQRFlag = "qr"
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
profileNameFlag = "profile"
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
)
@@ -51,7 +48,6 @@ var (
dnsLabels []string
dnsLabelsValidated domain.List
noBrowser bool
showQR bool
profileName string
configPath string
@@ -84,7 +80,6 @@ func init() {
)
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")

View File

@@ -1,65 +0,0 @@
package embed
import (
"io"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/capture"
)
// CaptureOptions configures a packet capture session.
type CaptureOptions struct {
// Output receives pcap-formatted data. Nil disables pcap output.
Output io.Writer
// TextOutput receives human-readable packet summaries. Nil disables text output.
TextOutput io.Writer
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
// Empty captures all packets.
Filter string
// Verbose adds seq/ack, TTL, window, and total length to text output.
Verbose bool
// ASCII dumps transport payload as printable ASCII after each packet line.
ASCII bool
}
// CaptureStats reports capture session counters.
type CaptureStats struct {
Packets int64
Bytes int64
Dropped int64
}
// CaptureSession represents an active packet capture. Call Stop to end the
// capture and flush buffered packets.
type CaptureSession struct {
sess *capture.Session
engine *internal.Engine
}
// Stop ends the capture, flushes remaining packets, and detaches from the device.
// Safe to call multiple times.
func (cs *CaptureSession) Stop() {
if cs.engine != nil {
_ = cs.engine.SetCapture(nil)
cs.engine = nil
}
if cs.sess != nil {
cs.sess.Stop()
}
}
// Stats returns current capture counters.
func (cs *CaptureSession) Stats() CaptureStats {
s := cs.sess.Stats()
return CaptureStats{
Packets: s.Packets,
Bytes: s.Bytes,
Dropped: s.Dropped,
}
}
// Done returns a channel that is closed when the capture's writer goroutine
// has fully exited and all buffered packets have been flushed.
func (cs *CaptureSession) Done() <-chan struct{} {
return cs.sess.Done()
}

View File

@@ -24,7 +24,6 @@ import (
"github.com/netbirdio/netbird/client/system"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util/capture"
)
var (
@@ -66,7 +65,7 @@ type Options struct {
PrivateKey string
// ManagementURL overrides the default management server URL
ManagementURL string
// PreSharedKey is the pre-shared key for the tunnel interface
// PreSharedKey is the pre-shared key for the WireGuard interface
PreSharedKey string
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
LogOutput io.Writer
@@ -82,9 +81,9 @@ type Options struct {
DisableClientRoutes bool
// BlockInbound blocks all inbound connections from peers
BlockInbound bool
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
WireguardPort *int
// MTU is the MTU for the tunnel interface.
// MTU is the MTU for the WireGuard interface.
// Valid values are in the range 576..8192 bytes.
// If non-nil, this value overrides any value stored in the config file.
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
@@ -470,52 +469,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
}
// StartCapture begins capturing packets on this client's tunnel device.
// Only one capture can be active at a time; starting a new one stops the previous.
// Call StopCapture (or CaptureSession.Stop) to end it.
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
engine, err := c.getEngine()
if err != nil {
return nil, err
}
var matcher capture.Matcher
if opts.Filter != "" {
m, err := capture.ParseFilter(opts.Filter)
if err != nil {
return nil, fmt.Errorf("parse filter: %w", err)
}
matcher = m
}
sess, err := capture.NewSession(capture.Options{
Output: opts.Output,
TextOutput: opts.TextOutput,
Matcher: matcher,
Verbose: opts.Verbose,
ASCII: opts.ASCII,
})
if err != nil {
return nil, fmt.Errorf("create capture session: %w", err)
}
if err := engine.SetCapture(sess); err != nil {
sess.Stop()
return nil, fmt.Errorf("set capture: %w", err)
}
return &CaptureSession{sess: sess, engine: engine}, nil
}
// StopCapture stops the active capture session if one is running.
func (c *Client) StopCapture() error {
engine, err := c.getEngine()
if err != nil {
return err
}
return engine.SetCapture(nil)
}
// getEngine safely retrieves the engine from the client with proper locking.
// Returns ErrClientNotStarted if the client is not started.
// Returns ErrEngineNotStarted if the engine is not available.

View File

@@ -56,13 +56,6 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogg
return createUserspaceFirewall(iface, nil, disableServerRoutes, flowLogger, mtu)
}
// Native firewall handles packet filtering, but the userspace WireGuard bind
// needs a device filter for DNS interception hooks. Install a minimal
// hooks-only filter that passes all traffic through to the kernel firewall.
if err := iface.SetFilter(&uspfilter.HooksFilter{}); err != nil {
log.Warnf("failed to set hooks filter, DNS via memory hooks will not work: %v", err)
}
return fm, nil
}

View File

@@ -1,11 +0,0 @@
// Package firewalld integrates with the firewalld daemon so NetBird can place
// its wg interface into firewalld's "trusted" zone. This is required because
// firewalld's nftables chains are created with NFT_CHAIN_OWNER on recent
// versions, which returns EPERM to any other process that tries to insert
// rules into them. The workaround mirrors what Tailscale does: let firewalld
// itself add the accept rules to its own chains by trusting the interface.
package firewalld
// TrustedZone is the firewalld zone name used for interfaces whose traffic
// should bypass firewalld filtering.
const TrustedZone = "trusted"

View File

@@ -1,260 +0,0 @@
//go:build linux
package firewalld
import (
"context"
"errors"
"fmt"
"os/exec"
"strings"
"sync"
"time"
"github.com/godbus/dbus/v5"
log "github.com/sirupsen/logrus"
)
const (
dbusDest = "org.fedoraproject.FirewallD1"
dbusPath = "/org/fedoraproject/FirewallD1"
dbusRootIface = "org.fedoraproject.FirewallD1"
dbusZoneIface = "org.fedoraproject.FirewallD1.zone"
errZoneAlreadySet = "ZONE_ALREADY_SET"
errAlreadyEnabled = "ALREADY_ENABLED"
errUnknownIface = "UNKNOWN_INTERFACE"
errNotEnabled = "NOT_ENABLED"
// callTimeout bounds each individual DBus or firewall-cmd invocation.
// A fresh context is created for each call so a slow DBus probe can't
// exhaust the deadline before the firewall-cmd fallback gets to run.
callTimeout = 3 * time.Second
)
var (
errDBusUnavailable = errors.New("firewalld dbus unavailable")
// trustLogOnce ensures the "added to trusted zone" message is logged at
// Info level only for the first successful add per process; repeat adds
// from other init paths are quieter.
trustLogOnce sync.Once
parentCtxMu sync.RWMutex
parentCtx context.Context = context.Background()
)
// SetParentContext installs a parent context whose cancellation aborts any
// in-flight TrustInterface call. It does not affect UntrustInterface, which
// always uses a fresh Background-rooted timeout so cleanup can still run
// during engine shutdown when the engine context is already cancelled.
func SetParentContext(ctx context.Context) {
parentCtxMu.Lock()
parentCtx = ctx
parentCtxMu.Unlock()
}
func getParentContext() context.Context {
parentCtxMu.RLock()
defer parentCtxMu.RUnlock()
return parentCtx
}
// TrustInterface places iface into firewalld's trusted zone if firewalld is
// running. It is idempotent and best-effort: errors are returned so callers
// can log, but a non-running firewalld is not an error. Only the first
// successful call per process logs at Info. Respects the parent context set
// via SetParentContext so startup-time cancellation unblocks it.
func TrustInterface(iface string) error {
parent := getParentContext()
if !isRunning(parent) {
return nil
}
if err := addTrusted(parent, iface); err != nil {
return fmt.Errorf("add %s to firewalld trusted zone: %w", iface, err)
}
trustLogOnce.Do(func() {
log.Infof("added %s to firewalld trusted zone", iface)
})
log.Debugf("firewalld: ensured %s is in trusted zone", iface)
return nil
}
// UntrustInterface removes iface from firewalld's trusted zone if firewalld
// is running. Idempotent. Uses a Background-rooted timeout so it still runs
// during shutdown after the engine context has been cancelled.
func UntrustInterface(iface string) error {
if !isRunning(context.Background()) {
return nil
}
if err := removeTrusted(context.Background(), iface); err != nil {
return fmt.Errorf("remove %s from firewalld trusted zone: %w", iface, err)
}
return nil
}
func newCallContext(parent context.Context) (context.Context, context.CancelFunc) {
return context.WithTimeout(parent, callTimeout)
}
func isRunning(parent context.Context) bool {
ctx, cancel := newCallContext(parent)
ok, err := isRunningDBus(ctx)
cancel()
if err == nil {
return ok
}
if errors.Is(err, errDBusUnavailable) || errors.Is(err, context.DeadlineExceeded) {
ctx, cancel = newCallContext(parent)
defer cancel()
return isRunningCLI(ctx)
}
return false
}
func addTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := addDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus add failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return addCLI(ctx, iface)
}
func removeTrusted(parent context.Context, iface string) error {
ctx, cancel := newCallContext(parent)
err := removeDBus(ctx, iface)
cancel()
if err == nil {
return nil
}
if !errors.Is(err, errDBusUnavailable) {
log.Debugf("firewalld: dbus remove failed, falling back to firewall-cmd: %v", err)
}
ctx, cancel = newCallContext(parent)
defer cancel()
return removeCLI(ctx, iface)
}
func isRunningDBus(ctx context.Context) (bool, error) {
conn, err := dbus.SystemBus()
if err != nil {
return false, fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
var zone string
if err := obj.CallWithContext(ctx, dbusRootIface+".getDefaultZone", 0).Store(&zone); err != nil {
return false, fmt.Errorf("firewalld getDefaultZone: %w", err)
}
return true, nil
}
func isRunningCLI(ctx context.Context) bool {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return false
}
return exec.CommandContext(ctx, "firewall-cmd", "--state").Run() == nil
}
func addDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".addInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errAlreadyEnabled) {
return nil
}
if dbusErrContains(call.Err, errZoneAlreadySet) {
move := obj.CallWithContext(ctx, dbusZoneIface+".changeZoneOfInterface", 0, TrustedZone, iface)
if move.Err != nil {
return fmt.Errorf("firewalld changeZoneOfInterface: %w", move.Err)
}
return nil
}
return fmt.Errorf("firewalld addInterface: %w", call.Err)
}
func removeDBus(ctx context.Context, iface string) error {
conn, err := dbus.SystemBus()
if err != nil {
return fmt.Errorf("%w: %v", errDBusUnavailable, err)
}
obj := conn.Object(dbusDest, dbusPath)
call := obj.CallWithContext(ctx, dbusZoneIface+".removeInterface", 0, TrustedZone, iface)
if call.Err == nil {
return nil
}
if dbusErrContains(call.Err, errUnknownIface) || dbusErrContains(call.Err, errNotEnabled) {
return nil
}
return fmt.Errorf("firewalld removeInterface: %w", call.Err)
}
func addCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
// --change-interface (no --permanent) binds the interface for the
// current runtime only; we do not want membership to persist across
// reboots because netbird re-asserts it on every startup.
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--change-interface="+iface,
).CombinedOutput()
if err != nil {
return fmt.Errorf("firewall-cmd change-interface: %w: %s", err, strings.TrimSpace(string(out)))
}
return nil
}
func removeCLI(ctx context.Context, iface string) error {
if _, err := exec.LookPath("firewall-cmd"); err != nil {
return fmt.Errorf("firewall-cmd not available: %w", err)
}
out, err := exec.CommandContext(ctx,
"firewall-cmd", "--zone="+TrustedZone, "--remove-interface="+iface,
).CombinedOutput()
if err != nil {
msg := strings.TrimSpace(string(out))
if strings.Contains(msg, errUnknownIface) || strings.Contains(msg, errNotEnabled) {
return nil
}
return fmt.Errorf("firewall-cmd remove-interface: %w: %s", err, msg)
}
return nil
}
func dbusErrContains(err error, code string) bool {
if err == nil {
return false
}
var de dbus.Error
if errors.As(err, &de) {
for _, b := range de.Body {
if s, ok := b.(string); ok && strings.Contains(s, code) {
return true
}
}
}
return strings.Contains(err.Error(), code)
}

View File

@@ -1,49 +0,0 @@
//go:build linux
package firewalld
import (
"errors"
"testing"
"github.com/godbus/dbus/v5"
)
func TestDBusErrContains(t *testing.T) {
tests := []struct {
name string
err error
code string
want bool
}{
{"nil error", nil, errZoneAlreadySet, false},
{"plain error match", errors.New("ZONE_ALREADY_SET: wt0"), errZoneAlreadySet, true},
{"plain error miss", errors.New("something else"), errZoneAlreadySet, false},
{
"dbus.Error body match",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"ZONE_ALREADY_SET: wt0"}},
errZoneAlreadySet,
true,
},
{
"dbus.Error body miss",
dbus.Error{Name: "org.fedoraproject.FirewallD1.Exception", Body: []any{"INVALID_INTERFACE"}},
errAlreadyEnabled,
false,
},
{
"dbus.Error non-string body falls back to Error()",
dbus.Error{Name: "x", Body: []any{123}},
"x",
true,
},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
got := dbusErrContains(tc.err, tc.code)
if got != tc.want {
t.Fatalf("dbusErrContains(%v, %q) = %v; want %v", tc.err, tc.code, got, tc.want)
}
})
}
}

View File

@@ -1,25 +0,0 @@
//go:build !linux
package firewalld
import "context"
// SetParentContext is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func SetParentContext(context.Context) {
// intentionally empty: firewalld is a Linux-only daemon
}
// TrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func TrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}
// UntrustInterface is a no-op on non-Linux platforms because firewalld only
// runs on Linux.
func UntrustInterface(string) error {
// intentionally empty: firewalld is a Linux-only daemon
return nil
}

View File

@@ -21,10 +21,6 @@ const (
// rules chains contains the effective ACL rules
chainNameInputRules = "NETBIRD-ACL-INPUT"
// mangleFwdKey is the entries map key for mangle FORWARD guard rules that prevent
// external DNAT from bypassing ACL rules.
mangleFwdKey = "MANGLE-FORWARD"
)
type aclEntries map[string][][]string
@@ -278,12 +274,6 @@ func (m *aclManager) cleanChains() error {
}
}
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.DeleteIfExists(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to delete mangle FORWARD guard rule: %v, %s", rule, err)
}
}
for _, ipsetName := range m.ipsetStore.ipsetNames() {
if err := m.flushIPSet(ipsetName); err != nil {
if errors.Is(err, ipset.ErrSetNotExist) {
@@ -313,10 +303,6 @@ func (m *aclManager) createDefaultChains() error {
}
for chainName, rules := range m.entries {
// mangle FORWARD guard rules are handled separately below
if chainName == mangleFwdKey {
continue
}
for _, rule := range rules {
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
log.Debugf("failed to create input chain jump rule: %s", err)
@@ -336,13 +322,6 @@ func (m *aclManager) createDefaultChains() error {
}
clear(m.optionalEntries)
// Insert mangle FORWARD guard rules to prevent external DNAT bypass.
for _, rule := range m.entries[mangleFwdKey] {
if err := m.iptablesClient.AppendUnique(tableMangle, chainFORWARD, rule...); err != nil {
log.Errorf("failed to add mangle FORWARD guard rule: %v", err)
}
}
return nil
}
@@ -364,22 +343,6 @@ func (m *aclManager) seedInitialEntries() {
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
// Mangle FORWARD guard: when external DNAT redirects traffic from the wg interface, it
// traverses FORWARD instead of INPUT, bypassing ACL rules. ACCEPT rules in filter FORWARD
// can be inserted above ours. Mangle runs before filter, so these guard rules enforce the
// ACL mark check where it cannot be overridden.
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED",
"-j", "ACCEPT",
})
m.appendToEntries(mangleFwdKey, []string{
"-i", m.wgIface.Name(),
"-m", "conntrack", "--ctstate", "DNAT",
"-m", "mark", "!", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
"-j", "DROP",
})
}
func (m *aclManager) seedInitialOptionalEntries() {

View File

@@ -12,7 +12,6 @@ import (
log "github.com/sirupsen/logrus"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -87,12 +86,6 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
}
// Trust after all fatal init steps so a later failure doesn't leave the
// interface in firewalld's trusted zone without a corresponding Close.
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
// persist early to ensure cleanup of chains
go func() {
if err := stateManager.PersistState(context.Background()); err != nil {
@@ -198,12 +191,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
merr = multierror.Append(merr, fmt.Errorf("reset router: %w", err))
}
// Appending to merr intentionally blocks DeleteState below so ShutdownState
// stays persisted and the crash-recovery path retries firewalld cleanup.
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
// attempt to delete state only if all other operations succeeded
if merr == nil {
if err := stateManager.DeleteState(&ShutdownState{}); err != nil {
@@ -230,11 +217,6 @@ func (m *Manager) AllowNetbird() error {
if err != nil {
return fmt.Errorf("allow netbird interface traffic: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}
@@ -382,6 +364,28 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for iptables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for iptables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChain() error {
if err := m.cleanupNoTrackChain(); err != nil {
log.Debugf("cleanup notrack chain: %v", err)

View File

@@ -89,6 +89,8 @@ type router struct {
stateManager *statemanager.Manager
ipFwdState *ipfwdstate.IPForwardingState
tproxyRules []tproxyRuleEntry
}
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -1109,3 +1111,92 @@ func (r *router) addPrefixToIPSet(name string, prefix netip.Prefix) error {
func (r *router) destroyIPSet(name string) error {
return ipset.Destroy(name)
}
// AddTProxyRule adds iptables nat PREROUTING REDIRECT rules for transparent proxy interception.
// Traffic from sources on dstPorts arriving on the WG interface is redirected
// to the transparent proxy listener on redirectPort.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
portStr := fmt.Sprintf("%d", redirectPort)
for _, proto := range []string{"tcp", "udp"} {
srcSpecs := r.buildSourceSpecs(sources)
for _, srcSpec := range srcSpecs {
if len(dstPorts) == 0 {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s: %w", ruleID, proto, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
} else {
for _, port := range dstPorts {
rule := append(srcSpec,
"-i", r.wgIface.Name(),
"-p", proto,
"--dport", fmt.Sprintf("%d", port),
"-j", "REDIRECT",
"--to-ports", portStr,
)
if err := r.iptablesClient.AppendUnique(tableNat, chainRTRDR, rule...); err != nil {
return fmt.Errorf("add redirect rule %s/%s/%d: %w", ruleID, proto, port, err)
}
r.tproxyRules = append(r.tproxyRules, tproxyRuleEntry{
ruleID: ruleID,
table: tableNat,
chain: chainRTRDR,
spec: rule,
})
}
}
}
}
return nil
}
// RemoveTProxyRule removes all iptables REDIRECT rules for the given ruleID.
func (r *router) RemoveTProxyRule(ruleID string) error {
var remaining []tproxyRuleEntry
for _, entry := range r.tproxyRules {
if entry.ruleID != ruleID {
remaining = append(remaining, entry)
continue
}
if err := r.iptablesClient.DeleteIfExists(entry.table, entry.chain, entry.spec...); err != nil {
log.Debugf("remove tproxy rule %s: %v", ruleID, err)
}
}
r.tproxyRules = remaining
return nil
}
type tproxyRuleEntry struct {
ruleID string
table string
chain string
spec []string
}
func (r *router) buildSourceSpecs(sources []netip.Prefix) [][]string {
if len(sources) == 0 {
return [][]string{{}} // empty spec = match any source
}
specs := make([][]string, 0, len(sources))
for _, src := range sources {
specs = append(specs, []string{"-s", src.String()})
}
return specs
}

View File

@@ -180,6 +180,22 @@ type Manager interface {
// SetupEBPFProxyNoTrack creates static notrack rules for eBPF proxy loopback traffic.
// This prevents conntrack from interfering with WireGuard proxy communication.
SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error
// AddTProxyRule adds TPROXY redirect rules for specific source CIDRs and destination ports.
// Traffic from sources on dstPorts is redirected to the transparent proxy on redirectPort.
// Empty dstPorts means redirect all ports.
AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error
// RemoveTProxyRule removes TPROXY redirect rules by ID.
RemoveTProxyRule(ruleID string) error
// AddUDPInspectionHook registers a hook that inspects UDP packets before forwarding.
// The hook receives the raw packet and returns true to drop it.
// Used for QUIC SNI-based blocking. Returns a hook ID for removal.
AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string
// RemoveUDPInspectionHook removes a previously registered inspection hook.
RemoveUDPInspectionHook(hookID string)
}
func GenKey(format string, pair RouterPair) string {

View File

@@ -14,7 +14,6 @@ import (
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/statemanager"
@@ -218,10 +217,6 @@ func (m *Manager) AllowNetbird() error {
return fmt.Errorf("flush allow input netbird rules: %w", err)
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}
@@ -487,6 +482,28 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return nil
}
// AddTProxyRule adds TPROXY redirect rules for the transparent proxy.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// RemoveTProxyRule removes TPROXY redirect rules by ID.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
m.mutex.Lock()
defer m.mutex.Unlock()
return m.router.RemoveTProxyRule(ruleID)
}
// AddUDPInspectionHook is a no-op for nftables (kernel-mode firewall has no userspace packet hooks).
func (m *Manager) AddUDPInspectionHook(_ uint16, _ func([]byte) bool) string { return "" }
// RemoveUDPInspectionHook is a no-op for nftables.
func (m *Manager) RemoveUDPInspectionHook(_ string) {}
func (m *Manager) initNoTrackChains(table *nftables.Table) error {
m.notrackOutputChain = m.rConn.AddChain(&nftables.Chain{
Name: chainNameRawOutput,

View File

@@ -19,7 +19,6 @@ import (
"golang.org/x/sys/unix"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewall "github.com/netbirdio/netbird/client/firewall/manager"
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
@@ -41,8 +40,6 @@ const (
chainNameForward = "FORWARD"
chainNameMangleForward = "netbird-mangle-forward"
firewalldTableName = "firewalld"
userDataAcceptForwardRuleIif = "frwacceptiif"
userDataAcceptForwardRuleOif = "frwacceptoif"
userDataAcceptInputRule = "inputaccept"
@@ -80,6 +77,7 @@ type router struct {
ipFwdState *ipfwdstate.IPForwardingState
legacyManagement bool
mtu uint16
}
func newRouter(workTable *nftables.Table, wgIface iFaceMapper, mtu uint16) (*router, error) {
@@ -136,10 +134,6 @@ func (r *router) Reset() error {
merr = multierror.Append(merr, fmt.Errorf("remove accept filter rules: %w", err))
}
if err := firewalld.UntrustInterface(r.wgIface.Name()); err != nil {
merr = multierror.Append(merr, err)
}
if err := r.removeNatPreroutingRules(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
}
@@ -287,10 +281,6 @@ func (r *router) createContainers() error {
log.Errorf("failed to add accept rules for the forward chain: %s", err)
}
if err := firewalld.TrustInterface(r.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
if err := r.refreshRulesMap(); err != nil {
log.Errorf("failed to refresh rules: %s", err)
}
@@ -1330,13 +1320,6 @@ func (r *router) isExternalChain(chain *nftables.Chain) bool {
return false
}
// Skip firewalld-owned chains. Firewalld creates its chains with the
// NFT_CHAIN_OWNER flag, so inserting rules into them returns EPERM.
// We delegate acceptance to firewalld by trusting the interface instead.
if chain.Table.Name == firewalldTableName {
return false
}
// Skip all iptables-managed tables in the ip family
if chain.Table.Family == nftables.TableFamilyIPv4 && isIptablesTable(chain.Table.Name) {
return false
@@ -2155,3 +2138,227 @@ func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any
},
}, nil
}
// AddTProxyRule adds nftables TPROXY redirect rules in the mangle prerouting chain.
// Traffic from sources on dstPorts arriving on the WG interface is redirected to
// the transparent proxy listener on redirectPort.
// Separate rules are created for TCP and UDP protocols.
func (r *router) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
// Use the nat redirect chain for DNAT rules.
// TPROXY doesn't work on WG kernel interfaces (socket assignment silently fails),
// so we use DNAT to 127.0.0.1:proxy_port instead. The proxy reads the original
// destination via SO_ORIGINAL_DST (conntrack).
chain := r.chains[chainNameRoutingRdr]
if chain == nil {
return fmt.Errorf("nat redirect chain not initialized")
}
for _, proto := range []uint8{unix.IPPROTO_TCP, unix.IPPROTO_UDP} {
protoName := "tcp"
if proto == unix.IPPROTO_UDP {
protoName = "udp"
}
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, protoName)
if existing, ok := r.rules[ruleKey]; ok && existing.Handle != 0 {
if err := r.decrementSetCounter(existing); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(existing); err != nil {
log.Debugf("remove existing tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
}
exprs, err := r.buildRedirectExprs(proto, sources, dstPorts, redirectPort)
if err != nil {
return fmt.Errorf("build redirect exprs for %s: %w", protoName, err)
}
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
Table: r.workTable,
Chain: chain,
Exprs: exprs,
UserData: []byte(ruleKey),
})
}
// Accept redirected packets in the ACL input chain. After REDIRECT, the
// destination port becomes the proxy port. Without this rule, the ACL filter
// drops the packet. We match on ct state dnat so only REDIRECT'd connections
// are accepted: direct connections to the proxy port are blocked.
inputAcceptKey := fmt.Sprintf("tproxy-%s-input", ruleID)
if _, ok := r.rules[inputAcceptKey]; !ok {
inputChain := &nftables.Chain{
Name: "netbird-acl-input-rules",
Table: r.workTable,
}
r.rules[inputAcceptKey] = r.conn.InsertRule(&nftables.Rule{
Table: r.workTable,
Chain: inputChain,
Exprs: []expr.Any{
// Only accept connections that were REDIRECT'd (ct status dnat)
&expr.Ct{Register: 1, Key: expr.CtKeySTATUS},
&expr.Bitwise{
SourceRegister: 1,
DestRegister: 1,
Len: 4,
Mask: binaryutil.NativeEndian.PutUint32(0x20), // IPS_DST_NAT
Xor: binaryutil.NativeEndian.PutUint32(0),
},
&expr.Cmp{
Op: expr.CmpOpNeq,
Register: 1,
Data: binaryutil.NativeEndian.PutUint32(0),
},
// Accept both TCP and UDP redirected to the proxy port.
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(redirectPort),
},
&expr.Verdict{Kind: expr.VerdictAccept},
},
UserData: []byte(inputAcceptKey),
})
}
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rules for %s: %w", ruleID, err)
}
return nil
}
// RemoveTProxyRule removes TPROXY redirect rules by ID (both TCP and UDP variants).
func (r *router) RemoveTProxyRule(ruleID string) error {
if err := r.refreshRulesMap(); err != nil {
return fmt.Errorf(refreshRulesMapError, err)
}
var removed int
for _, suffix := range []string{"tcp", "udp", "input"} {
ruleKey := fmt.Sprintf("tproxy-%s-%s", ruleID, suffix)
rule, ok := r.rules[ruleKey]
if !ok {
continue
}
if rule.Handle == 0 {
delete(r.rules, ruleKey)
continue
}
if err := r.decrementSetCounter(rule); err != nil {
log.Debugf("decrement set counter for %s: %v", ruleKey, err)
}
if err := r.conn.DelRule(rule); err != nil {
log.Debugf("delete tproxy rule %s: %v", ruleKey, err)
}
delete(r.rules, ruleKey)
removed++
}
if removed > 0 {
if err := r.conn.Flush(); err != nil {
return fmt.Errorf("flush tproxy rule removal for %s: %w", ruleID, err)
}
}
return nil
}
// buildRedirectExprs builds nftables expressions for a REDIRECT rule.
// Matches WG interface ingress, source CIDRs, destination ports, then REDIRECTs to the proxy port.
func (r *router) buildRedirectExprs(proto uint8, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) ([]expr.Any, error) {
var exprs []expr.Any
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: ifname(r.wgIface.Name())},
)
exprs = append(exprs,
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
&expr.Cmp{Op: expr.CmpOpEq, Register: 1, Data: []byte{proto}},
)
// Source CIDRs use the named ipset shared with route rules.
if len(sources) > 0 {
srcSet := firewall.NewPrefixSet(sources)
srcExprs, err := r.getIpSet(srcSet, sources, true)
if err != nil {
return nil, fmt.Errorf("get source ipset: %w", err)
}
exprs = append(exprs, srcExprs...)
}
if len(dstPorts) == 1 {
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Cmp{
Op: expr.CmpOpEq,
Register: 1,
Data: binaryutil.BigEndian.PutUint16(dstPorts[0]),
},
)
} else if len(dstPorts) > 1 {
setElements := make([]nftables.SetElement, len(dstPorts))
for i, p := range dstPorts {
setElements[i] = nftables.SetElement{Key: binaryutil.BigEndian.PutUint16(p)}
}
portSet := &nftables.Set{
Table: r.workTable,
Anonymous: true,
Constant: true,
KeyType: nftables.TypeInetService,
}
if err := r.conn.AddSet(portSet, setElements); err != nil {
return nil, fmt.Errorf("create port set: %w", err)
}
exprs = append(exprs,
&expr.Payload{
DestRegister: 1,
Base: expr.PayloadBaseTransportHeader,
Offset: 2,
Len: 2,
},
&expr.Lookup{
SourceRegister: 1,
SetName: portSet.Name,
SetID: portSet.ID,
},
)
}
// REDIRECT to local proxy port. Changes the destination to the interface's
// primary address + specified port. Conntrack tracks the original destination,
// readable via SO_ORIGINAL_DST.
exprs = append(exprs,
&expr.Immediate{Register: 1, Data: binaryutil.BigEndian.PutUint16(redirectPort)},
&expr.Redir{
RegisterProtoMin: 1,
},
)
return exprs, nil
}

View File

@@ -3,9 +3,6 @@
package uspfilter
import (
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/firewalld"
"github.com/netbirdio/netbird/client/internal/statemanager"
)
@@ -19,9 +16,6 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
if m.nativeFirewall != nil {
return m.nativeFirewall.Close(stateManager)
}
if err := firewalld.UntrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to untrust interface in firewalld: %v", err)
}
return nil
}
@@ -30,8 +24,5 @@ func (m *Manager) AllowNetbird() error {
if m.nativeFirewall != nil {
return m.nativeFirewall.AllowNetbird()
}
if err := firewalld.TrustInterface(m.wgIface.Name()); err != nil {
log.Warnf("failed to trust interface in firewalld: %v", err)
}
return nil
}

View File

@@ -1,37 +0,0 @@
package common
import (
"net/netip"
"sync/atomic"
)
// PacketHook stores a registered hook for a specific IP:port.
type PacketHook struct {
IP netip.Addr
Port uint16
Fn func([]byte) bool
}
// HookMatches checks if a packet's destination matches the hook and invokes it.
func HookMatches(h *PacketHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.IP == dstIP && h.Port == dport {
return h.Fn(packetData)
}
return false
}
// SetHook atomically stores a hook, handling nil removal.
func SetHook(ptr *atomic.Pointer[PacketHook], ip netip.Addr, dPort uint16, hook func([]byte) bool) {
if hook == nil {
ptr.Store(nil)
return
}
ptr.Store(&PacketHook{
IP: ip,
Port: dPort,
Fn: hook,
})
}

View File

@@ -9,7 +9,6 @@ import (
// IFaceMapper defines subset methods of interface required for manager
type IFaceMapper interface {
Name() string
SetFilter(device.PacketFilter) error
Address() wgaddr.Address
GetWGDevice() *wgdevice.Device

View File

@@ -115,13 +115,12 @@ type Manager struct {
localipmanager *localIPManager
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
forwarder atomic.Pointer[forwarder.Forwarder]
pendingCapture atomic.Pointer[forwarder.PacketCapture]
logger *nblog.Logger
flowLogger nftypes.FlowLogger
udpTracker *conntrack.UDPTracker
icmpTracker *conntrack.ICMPTracker
tcpTracker *conntrack.TCPTracker
forwarder atomic.Pointer[forwarder.Forwarder]
logger *nblog.Logger
flowLogger nftypes.FlowLogger
blockRule firewall.Rule
@@ -143,8 +142,15 @@ type Manager struct {
mssClampEnabled bool
// Only one hook per protocol is supported. Outbound direction only.
udpHookOut atomic.Pointer[common.PacketHook]
tcpHookOut atomic.Pointer[common.PacketHook]
udpHookOut atomic.Pointer[packetHook]
tcpHookOut atomic.Pointer[packetHook]
}
// packetHook stores a registered hook for a specific IP:port.
type packetHook struct {
ip netip.Addr
port uint16
fn func([]byte) bool
}
// decoder for packages
@@ -352,19 +358,6 @@ func (m *Manager) determineRouting() error {
return nil
}
// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
if pc == nil {
m.pendingCapture.Store(nil)
} else {
m.pendingCapture.Store(&pc)
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.SetCapture(pc)
}
}
// initForwarder initializes the forwarder, it disables routing on errors
func (m *Manager) initForwarder() error {
if m.forwarder.Load() != nil {
@@ -386,11 +379,6 @@ func (m *Manager) initForwarder() error {
m.forwarder.Store(forwarder)
// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
if pc := m.pendingCapture.Load(); pc != nil {
forwarder.SetCapture(*pc)
}
log.Debug("forwarder initialized")
return nil
@@ -633,7 +621,6 @@ func (m *Manager) resetState() {
}
if fwder := m.forwarder.Load(); fwder != nil {
fwder.SetCapture(nil)
fwder.Stop()
}
@@ -654,6 +641,45 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
return m.nativeFirewall.SetupEBPFProxyNoTrack(proxyPort, wgPort)
}
// AddTProxyRule delegates to the native firewall for TPROXY rules.
// In userspace mode (no native firewall), this is a no-op since the
// forwarder intercepts traffic directly.
func (m *Manager) AddTProxyRule(ruleID string, sources []netip.Prefix, dstPorts []uint16, redirectPort uint16) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.AddTProxyRule(ruleID, sources, dstPorts, redirectPort)
}
// AddUDPInspectionHook registers a hook for QUIC/UDP inspection via the packet filter.
func (m *Manager) AddUDPInspectionHook(dstPort uint16, hook func(packet []byte) bool) string {
m.SetUDPPacketHook(netip.Addr{}, dstPort, hook)
return "udp-inspection"
}
// RemoveUDPInspectionHook removes a previously registered inspection hook.
func (m *Manager) RemoveUDPInspectionHook(_ string) {
m.SetUDPPacketHook(netip.Addr{}, 0, nil)
}
// RemoveTProxyRule delegates to the native firewall for TPROXY rules.
func (m *Manager) RemoveTProxyRule(ruleID string) error {
if m.nativeFirewall == nil {
return nil
}
return m.nativeFirewall.RemoveTProxyRule(ruleID)
}
// IsLocalIP reports whether the given IP belongs to the local machine.
func (m *Manager) IsLocalIP(ip netip.Addr) bool {
return m.localipmanager.IsLocalIP(ip)
}
// GetForwarder returns the userspace packet forwarder, or nil if not initialized.
func (m *Manager) GetForwarder() *forwarder.Forwarder {
return m.forwarder.Load()
}
// UpdateSet updates the rule destinations associated with the given set
// by merging the existing prefixes with the new ones, then deduplicating.
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
@@ -925,11 +951,21 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
}
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return common.HookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
return hookMatches(m.udpHookOut.Load(), dstIP, dport, packetData)
}
func (m *Manager) tcpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
return common.HookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
return hookMatches(m.tcpHookOut.Load(), dstIP, dport, packetData)
}
func hookMatches(h *packetHook, dstIP netip.Addr, dport uint16, packetData []byte) bool {
if h == nil {
return false
}
if h.ip == dstIP && h.port == dport {
return h.fn(packetData)
}
return false
}
// filterInbound implements filtering logic for incoming packets.
@@ -1340,12 +1376,28 @@ func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, prot
// SetUDPPacketHook sets the outbound UDP packet hook. Pass nil hook to remove.
func (m *Manager) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
common.SetHook(&m.udpHookOut, ip, dPort, hook)
if hook == nil {
m.udpHookOut.Store(nil)
return
}
m.udpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetTCPPacketHook sets the outbound TCP packet hook. Pass nil hook to remove.
func (m *Manager) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool) {
common.SetHook(&m.tcpHookOut, ip, dPort, hook)
if hook == nil {
m.tcpHookOut.Store(nil)
return
}
m.tcpHookOut.Store(&packetHook{
ip: ip,
port: dPort,
fn: hook,
})
}
// SetLogLevel sets the log level for the firewall manager

View File

@@ -31,20 +31,12 @@ var logger = log.NewFromLogrus(logrus.StandardLogger())
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
type IFaceMock struct {
NameFunc func() string
SetFilterFunc func(device.PacketFilter) error
AddressFunc func() wgaddr.Address
GetWGDeviceFunc func() *wgdevice.Device
GetDeviceFunc func() *device.FilteredDevice
}
func (i *IFaceMock) Name() string {
if i.NameFunc == nil {
return "wgtest"
}
return i.NameFunc()
}
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
if i.GetWGDeviceFunc == nil {
return nil
@@ -210,9 +202,9 @@ func TestSetUDPPacketHook(t *testing.T) {
h := manager.udpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(8000), h.Port)
assert.True(t, h.Fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(8000), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetUDPPacketHook(netip.MustParseAddr("10.168.0.1"), 8000, nil)
@@ -234,9 +226,9 @@ func TestSetTCPPacketHook(t *testing.T) {
h := manager.tcpHookOut.Load()
require.NotNil(t, h)
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.IP)
assert.Equal(t, uint16(53), h.Port)
assert.True(t, h.Fn(nil))
assert.Equal(t, netip.MustParseAddr("10.168.0.1"), h.ip)
assert.Equal(t, uint16(53), h.port)
assert.True(t, h.fn(nil))
assert.True(t, called)
manager.SetTCPPacketHook(netip.MustParseAddr("10.168.0.1"), 53, nil)

View File

@@ -12,19 +12,12 @@ import (
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
)
// PacketCapture captures raw packets for debugging. Implementations must be
// safe for concurrent use and must not block.
type PacketCapture interface {
Offer(data []byte, outbound bool)
}
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
type endpoint struct {
logger *nblog.Logger
dispatcher stack.NetworkDispatcher
device *wgdevice.Device
mtu atomic.Uint32
capture atomic.Pointer[PacketCapture]
}
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -61,17 +54,13 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
continue
}
pktBytes := data.AsSlice()
// Send the packet through WireGuard
address := netHeader.DestinationAddress()
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
if err != nil {
e.logger.Error1("CreateOutboundPacket: %v", err)
continue
}
if pc := e.capture.Load(); pc != nil {
(*pc).Offer(pktBytes, true)
}
written++
}

View File

@@ -7,6 +7,7 @@ import (
"net/netip"
"runtime"
"sync"
"sync/atomic"
"time"
log "github.com/sirupsen/logrus"
@@ -21,6 +22,7 @@ import (
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
"github.com/netbirdio/netbird/client/inspect"
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -46,6 +48,10 @@ type Forwarder struct {
netstack bool
hasRawICMPAccess bool
pingSemaphore chan struct{}
// proxy is the optional inspection engine.
// When set, TCP connections are handed to the engine for protocol detection
// and rule evaluation. Swapped atomically for lock-free hot-path access.
proxy atomic.Pointer[inspect.Proxy]
}
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool, mtu uint16) (*Forwarder, error) {
@@ -79,7 +85,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
}
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
return nil, fmt.Errorf("failed to add protocol address: %s", err)
return nil, fmt.Errorf("add protocol address: %s", err)
}
defaultSubnet, err := tcpip.NewSubnet(
@@ -139,16 +145,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
return f, nil
}
// SetCapture sets or clears the packet capture on the forwarder endpoint.
// This captures outbound packets that bypass the FilteredDevice (netstack forwarding).
func (f *Forwarder) SetCapture(pc PacketCapture) {
if pc == nil {
f.endpoint.capture.Store(nil)
return
}
f.endpoint.capture.Store(&pc)
}
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
if len(payload) < header.IPv4MinimumSize {
return fmt.Errorf("packet too small: %d bytes", len(payload))
@@ -165,6 +161,13 @@ func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
return nil
}
// SetProxy sets the inspection engine. When set, TCP connections are handed
// to it for protocol detection and rule evaluation instead of direct relay.
// Pass nil to disable inspection.
func (f *Forwarder) SetProxy(p *inspect.Proxy) {
f.proxy.Store(p)
}
// Stop gracefully shuts down the forwarder
func (f *Forwarder) Stop() {
f.cancel()
@@ -177,6 +180,25 @@ func (f *Forwarder) Stop() {
f.stack.Wait()
}
// CheckUDPPacket inspects a UDP payload against proxy rules before injection.
// This is called by the filter for QUIC SNI-based blocking.
// Returns true if the packet should be allowed, false if it should be dropped.
func (f *Forwarder) CheckUDPPacket(payload []byte, srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) bool {
p := f.proxy.Load()
if p == nil {
return true
}
dst := netip.AddrPortFrom(dstIP, dstPort)
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(ruleID),
}
action := p.HandleUDPPacket(payload, dst, src)
return action != inspect.ActionBlock
}
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
if f.netstack && f.ip.Equal(addr) {
return net.IPv4(127, 0, 0, 1)

View File

@@ -270,9 +270,5 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []
return 0
}
if pc := f.endpoint.capture.Load(); pc != nil {
(*pc).Offer(fullPacket, true)
}
return len(fullPacket)
}

View File

@@ -16,6 +16,7 @@ import (
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
"gvisor.dev/gvisor/pkg/waiter"
"github.com/netbirdio/netbird/client/inspect"
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
)
@@ -23,6 +24,86 @@ import (
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
id := r.ID()
// If the inspection engine is configured, accept the connection first and hand it off.
if p := f.proxy.Load(); p != nil {
f.handleTCPWithInspection(r, id, p)
return
}
f.handleTCPDirect(r, id)
}
// handleTCPWithInspection accepts the connection and hands it to the inspection
// engine. For allow decisions, the forwarder does its own relay (passthrough).
// For block/inspect, the engine handles everything internally.
func (f *Forwarder) handleTCPWithInspection(r *tcp.ForwarderRequest, id stack.TransportEndpointID, p *inspect.Proxy) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
if epErr != nil {
f.logger.Error1("forwarder: create TCP endpoint for inspection: %v", epErr)
r.Complete(true)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
srcIP := netip.AddrFrom4(id.RemoteAddress.As4())
dstIP := netip.AddrFrom4(id.LocalAddress.As4())
dst := netip.AddrPortFrom(dstIP, id.LocalPort)
var policyID []byte
if ruleID, ok := f.getRuleID(srcIP, dstIP, id.RemotePort, id.LocalPort); ok {
policyID = ruleID
}
src := inspect.SourceInfo{
IP: srcIP,
PolicyID: inspect.PolicyID(policyID),
}
f.logger.Trace1("forwarder: handing TCP %v to inspection engine", epID(id))
go func() {
result, err := p.InspectTCP(f.ctx, inConn, dst, src)
if err != nil && err != inspect.ErrBlocked {
f.logger.Debug2("forwarder: inspection error for %v: %v", epID(id), err)
}
// Passthrough: engine returned allow, forwarder does the relay.
if result.PassthroughConn != nil {
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
outConn, dialErr := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
if dialErr != nil {
f.logger.Trace2("forwarder: passthrough dial error for %v: %v", epID(id), dialErr)
if closeErr := result.PassthroughConn.Close(); closeErr != nil {
f.logger.Debug1("forwarder: close passthrough conn: %v", closeErr)
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
return
}
f.proxyTCPPassthrough(id, result.PassthroughConn, outConn, ep, flowID)
return
}
// Engine handled it (block/inspect/HTTP). Capture stats and clean up.
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
ep.Close()
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, rxPackets, txPackets)
}()
}
// handleTCPDirect handles TCP connections with direct relay (no proxy).
func (f *Forwarder) handleTCPDirect(r *tcp.ForwarderRequest, id stack.TransportEndpointID) {
flowID := uuid.New()
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
@@ -42,7 +123,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return
}
// Create wait queue for blocking syscalls
wq := waiter.Queue{}
ep, epErr := r.CreateEndpoint(&wq)
@@ -55,7 +135,6 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
return
}
// Complete the handshake
r.Complete(false)
inConn := gonet.NewTCPConn(&wq, ep)
@@ -73,7 +152,6 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
go func() {
<-ctx.Done()
// Close connections and endpoint.
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: inConn close error: %v", err)
}
@@ -132,6 +210,66 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
}
// proxyTCPPassthrough relays traffic between a peeked inbound connection
// (from the inspection engine passthrough) and the outbound connection.
// It accepts net.Conn for inConn since the inspection engine wraps it in a peekConn.
func (f *Forwarder) proxyTCPPassthrough(id stack.TransportEndpointID, inConn net.Conn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
ctx, cancel := context.WithCancel(f.ctx)
defer cancel()
go func() {
<-ctx.Done()
if err := inConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough inConn close: %v", err)
}
if err := outConn.Close(); err != nil && !isClosedError(err) {
f.logger.Debug1("forwarder: passthrough outConn close: %v", err)
}
ep.Close()
}()
var wg sync.WaitGroup
wg.Add(2)
var (
bytesIn int64
bytesOut int64
errIn error
errOut error
)
go func() {
bytesIn, errIn = io.Copy(outConn, inConn)
cancel()
wg.Done()
}()
go func() {
bytesOut, errOut = io.Copy(inConn, outConn)
cancel()
wg.Done()
}()
wg.Wait()
if errIn != nil && !isClosedError(errIn) {
f.logger.Error2("proxyTCPPassthrough: copy error (in→out) for %s: %v", epID(id), errIn)
}
if errOut != nil && !isClosedError(errOut) {
f.logger.Error2("proxyTCPPassthrough: copy error (out→in) for %s: %v", epID(id), errOut)
}
var rxPackets, txPackets uint64
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
rxPackets = tcpStats.SegmentsSent.Value()
txPackets = tcpStats.SegmentsReceived.Value()
}
f.logger.Trace5("forwarder: passthrough TCP %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesOut, txPackets, bytesIn)
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesOut), uint64(bytesIn), rxPackets, txPackets)
}
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
dstIp := netip.AddrFrom4(id.LocalAddress.As4())

View File

@@ -1,90 +0,0 @@
package uspfilter
import (
"encoding/binary"
"net/netip"
"sync/atomic"
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
"github.com/netbirdio/netbird/client/iface/device"
)
const (
ipv4HeaderMinLen = 20
ipv4ProtoOffset = 9
ipv4FlagsOffset = 6
ipv4DstOffset = 16
ipProtoUDP = 17
ipProtoTCP = 6
ipv4FragOffMask = 0x1fff
// dstPortOffset is the offset of the destination port within a UDP or TCP header.
dstPortOffset = 2
)
// HooksFilter is a minimal packet filter that only handles outbound DNS hooks.
// It is installed on the WireGuard interface when the userspace bind is active
// but a full firewall filter (Manager) is not needed because a native kernel
// firewall (nftables/iptables) handles packet filtering.
type HooksFilter struct {
udpHook atomic.Pointer[common.PacketHook]
tcpHook atomic.Pointer[common.PacketHook]
}
var _ device.PacketFilter = (*HooksFilter)(nil)
// FilterOutbound checks outbound packets for DNS hook matches.
// Only IPv4 packets matching the registered hook IP:port are intercepted.
// IPv6 and non-IP packets pass through unconditionally.
func (f *HooksFilter) FilterOutbound(packetData []byte, _ int) bool {
if len(packetData) < ipv4HeaderMinLen {
return false
}
// Only process IPv4 packets, let everything else pass through.
if packetData[0]>>4 != 4 {
return false
}
ihl := int(packetData[0]&0x0f) * 4
if ihl < ipv4HeaderMinLen || len(packetData) < ihl+4 {
return false
}
// Skip non-first fragments: they don't carry L4 headers.
flagsAndOffset := binary.BigEndian.Uint16(packetData[ipv4FlagsOffset : ipv4FlagsOffset+2])
if flagsAndOffset&ipv4FragOffMask != 0 {
return false
}
dstIP, ok := netip.AddrFromSlice(packetData[ipv4DstOffset : ipv4DstOffset+4])
if !ok {
return false
}
proto := packetData[ipv4ProtoOffset]
dstPort := binary.BigEndian.Uint16(packetData[ihl+dstPortOffset : ihl+dstPortOffset+2])
switch proto {
case ipProtoUDP:
return common.HookMatches(f.udpHook.Load(), dstIP, dstPort, packetData)
case ipProtoTCP:
return common.HookMatches(f.tcpHook.Load(), dstIP, dstPort, packetData)
default:
return false
}
}
// FilterInbound allows all inbound packets (native firewall handles filtering).
func (f *HooksFilter) FilterInbound([]byte, int) bool {
return false
}
// SetUDPPacketHook registers the UDP packet hook.
func (f *HooksFilter) SetUDPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.udpHook, ip, dPort, hook)
}
// SetTCPPacketHook registers the TCP packet hook.
func (f *HooksFilter) SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func([]byte) bool) {
common.SetHook(&f.tcpHook, ip, dPort, hook)
}

View File

@@ -239,12 +239,8 @@ func TestICEBind_HandlesConcurrentMixedTraffic(t *testing.T) {
ipv6Count++
}
// Allow some UDP packet loss under load (e.g. FreeBSD/QEMU runners). The
// routing-correctness checks above are the real assertions; the counts
// are a sanity bound to catch a totally silent path.
minDelivered := packetsPerFamily * 80 / 100
assert.GreaterOrEqual(t, ipv4Count, minDelivered, "IPv4 delivery below threshold")
assert.GreaterOrEqual(t, ipv6Count, minDelivered, "IPv6 delivery below threshold")
assert.Equal(t, packetsPerFamily, ipv4Count)
assert.Equal(t, packetsPerFamily, ipv6Count)
}
func TestICEBind_DetectsAddressFamilyFromConnection(t *testing.T) {

View File

@@ -3,7 +3,6 @@ package device
import (
"net/netip"
"sync"
"sync/atomic"
"golang.zx2c4.com/wireguard/tun"
)
@@ -29,20 +28,11 @@ type PacketFilter interface {
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
}
// PacketCapture captures raw packets for debugging. Implementations must be
// safe for concurrent use and must not block.
type PacketCapture interface {
// Offer submits a packet for capture. outbound is true for packets
// leaving the host (Read path), false for packets arriving (Write path).
Offer(data []byte, outbound bool)
}
// FilteredDevice to override Read or Write of packets
type FilteredDevice struct {
tun.Device
filter PacketFilter
capture atomic.Pointer[PacketCapture]
mutex sync.RWMutex
closeOnce sync.Once
}
@@ -73,25 +63,20 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
return 0, err
}
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
if filter != nil {
for i := 0; i < n; i++ {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
i--
}
}
if filter == nil {
return
}
if pc := d.capture.Load(); pc != nil {
for i := 0; i < n; i++ {
(*pc).Offer(bufs[i][offset:offset+sizes[i]], true)
for i := 0; i < n; i++ {
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
bufs = append(bufs[:i], bufs[i+1:]...)
sizes = append(sizes[:i], sizes[i+1:]...)
n--
i--
}
}
@@ -100,13 +85,6 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
// Write wraps write method with filtering feature
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
// Capture before filtering so dropped packets are still visible in captures.
if pc := d.capture.Load(); pc != nil {
for _, buf := range bufs {
(*pc).Offer(buf[offset:], false)
}
}
d.mutex.RLock()
filter := d.filter
d.mutex.RUnlock()
@@ -118,10 +96,9 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
filteredBufs := make([][]byte, 0, len(bufs))
dropped := 0
for _, buf := range bufs {
if filter.FilterInbound(buf[offset:], len(buf)) {
dropped++
} else {
if !filter.FilterInbound(buf[offset:], len(buf)) {
filteredBufs = append(filteredBufs, buf)
dropped++
}
}
@@ -136,14 +113,3 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
d.filter = filter
d.mutex.Unlock()
}
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
// Uses atomic store so the hot path (Read/Write) is a single pointer load
// with no locking overhead when capture is off.
func (d *FilteredDevice) SetCapture(pc PacketCapture) {
if pc == nil {
d.capture.Store(nil)
return
}
d.capture.Store(&pc)
}

View File

@@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) {
t.Errorf("unexpected error: %v", err)
return
}
if n != 1 {
if n != 0 {
t.Errorf("expected n=1, got %d", n)
return
}

View File

@@ -217,6 +217,7 @@ func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
// Close closes the tunnel interface
func (w *WGIface) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
var result *multierror.Error
@@ -224,15 +225,7 @@ func (w *WGIface) Close() error {
result = multierror.Append(result, fmt.Errorf("failed to free WireGuard proxy: %w", err))
}
// Release w.mu before calling w.tun.Close(): the underlying
// wireguard-go device.Close() waits for its send/receive goroutines
// to drain. Some of those goroutines re-enter WGIface methods that
// take w.mu (e.g. the packet filter DNS hook calls GetDevice()), so
// holding the mutex here would deadlock the shutdown path.
tun := w.tun
w.mu.Unlock()
if err := tun.Close(); err != nil {
if err := w.tun.Close(); err != nil {
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
}

View File

@@ -1,113 +0,0 @@
//go:build !android
package iface
import (
"errors"
"sync"
"testing"
"time"
wgdevice "golang.zx2c4.com/wireguard/device"
"golang.zx2c4.com/wireguard/tun/netstack"
"github.com/netbirdio/netbird/client/iface/device"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/iface/wgproxy"
)
// fakeTunDevice implements WGTunDevice and lets the test control when
// Close() returns. It mimics the wireguard-go shutdown path, which blocks
// until its goroutines drain. Some of those goroutines (e.g. the packet
// filter DNS hook in client/internal/dns) call back into WGIface, so if
// WGIface.Close() held w.mu across tun.Close() the shutdown would
// deadlock.
type fakeTunDevice struct {
closeStarted chan struct{}
unblockClose chan struct{}
}
func (f *fakeTunDevice) Create() (device.WGConfigurer, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) Up() (*udpmux.UniversalUDPMuxDefault, error) {
return nil, errors.New("not implemented")
}
func (f *fakeTunDevice) UpdateAddr(wgaddr.Address) error { return nil }
func (f *fakeTunDevice) WgAddress() wgaddr.Address { return wgaddr.Address{} }
func (f *fakeTunDevice) MTU() uint16 { return DefaultMTU }
func (f *fakeTunDevice) DeviceName() string { return "nb-close-test" }
func (f *fakeTunDevice) FilteredDevice() *device.FilteredDevice { return nil }
func (f *fakeTunDevice) Device() *wgdevice.Device { return nil }
func (f *fakeTunDevice) GetNet() *netstack.Net { return nil }
func (f *fakeTunDevice) GetICEBind() device.EndpointManager { return nil }
func (f *fakeTunDevice) Close() error {
close(f.closeStarted)
<-f.unblockClose
return nil
}
type fakeProxyFactory struct{}
func (fakeProxyFactory) GetProxy() wgproxy.Proxy { return nil }
func (fakeProxyFactory) GetProxyPort() uint16 { return 0 }
func (fakeProxyFactory) Free() error { return nil }
// TestWGIface_CloseReleasesMutexBeforeTunClose guards against a deadlock
// that surfaces as a macOS test-timeout in
// TestDNSPermanent_updateUpstream: WGIface.Close() used to hold w.mu
// while waiting for the wireguard-go device goroutines to finish, and
// one of those goroutines (the DNS filter hook) calls back into
// WGIface.GetDevice() which needs the same mutex. The fix is to drop
// the lock before tun.Close() returns control.
func TestWGIface_CloseReleasesMutexBeforeTunClose(t *testing.T) {
tun := &fakeTunDevice{
closeStarted: make(chan struct{}),
unblockClose: make(chan struct{}),
}
w := &WGIface{
tun: tun,
wgProxyFactory: fakeProxyFactory{},
}
closeDone := make(chan error, 1)
go func() {
closeDone <- w.Close()
}()
select {
case <-tun.closeStarted:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
t.Fatal("tun.Close() was never invoked")
}
// Simulate the WireGuard read goroutine calling back into WGIface
// via the packet filter's DNS hook. If Close() still held w.mu
// during tun.Close(), this would block until the test timeout.
getDeviceDone := make(chan struct{})
var wg sync.WaitGroup
wg.Add(1)
go func() {
defer wg.Done()
_ = w.GetDevice()
close(getDeviceDone)
}()
select {
case <-getDeviceDone:
case <-time.After(2 * time.Second):
close(tun.unblockClose)
wg.Wait()
t.Fatal("GetDevice() deadlocked while WGIface.Close was closing the tun")
}
close(tun.unblockClose)
select {
case <-closeDone:
case <-time.After(2 * time.Second):
t.Fatal("WGIface.Close() never returned after the tun was unblocked")
}
}

View File

@@ -171,7 +171,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
}
if u.address.Network.Contains(a) {
log.Warnf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
}
@@ -181,7 +181,7 @@ func (u *UDPConn) performFilterCheck(addr net.Addr) error {
u.addrCache.Store(addr.String(), isRouted)
if isRouted {
// Extra log, as the error only shows up with ICE logging enabled
log.Infof("address %s is part of routed network %s, refusing to write", addr, prefix)
log.Infof("Address %s is part of routed network %s, refusing to write", addr, prefix)
return fmt.Errorf("address %s is part of routed network %s, refusing to write", addr, prefix)
}
}

212
client/inspect/config.go Normal file
View File

@@ -0,0 +1,212 @@
package inspect
import (
"crypto"
"crypto/x509"
"net"
"net/netip"
"net/url"
"strings"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// InspectResult holds the outcome of connection inspection.
type InspectResult struct {
// Action is the rule evaluation result.
Action Action
// PassthroughConn is the client connection with buffered peeked bytes.
// Non-nil only when Action is ActionAllow and the caller should relay
// (TLS passthrough or non-HTTP/TLS protocol). The caller takes ownership
// and is responsible for closing this connection.
PassthroughConn net.Conn
}
const (
// DefaultTProxyPort is the default TPROXY listener port for kernel mode.
// Override with NB_TPROXY_PORT environment variable.
DefaultTProxyPort = 22080
)
// Action determines how the proxy handles a matched connection.
type Action string
const (
// ActionAllow passes the connection through without decryption.
ActionAllow Action = "allow"
// ActionBlock denies the connection.
ActionBlock Action = "block"
// ActionInspect decrypts (MITM) and inspects the connection.
ActionInspect Action = "inspect"
)
// ProxyMode determines the proxy operating mode.
type ProxyMode string
const (
// ModeBuiltin uses the built-in proxy with rules and optional ICAP.
ModeBuiltin ProxyMode = "builtin"
// ModeEnvoy runs a local envoy sidecar for L7 processing.
// Go manages envoy lifecycle, config generation, and rule evaluation.
// USP path forwards via PROXY protocol v2; kernel path uses nftables redirect.
ModeEnvoy ProxyMode = "envoy"
// ModeExternal forwards all traffic to an external proxy.
ModeExternal ProxyMode = "external"
)
// PolicyID is the management policy identifier associated with a connection.
type PolicyID []byte
// MatchDomain reports whether target matches the pattern.
// If pattern starts with "*.", it matches any subdomain (but not the base itself).
// Otherwise it requires an exact match.
func MatchDomain(pattern, target domain.Domain) bool {
p := pattern.PunycodeString()
t := target.PunycodeString()
if strings.HasPrefix(p, "*.") {
base := p[2:]
return strings.HasSuffix(t, "."+base)
}
return p == t
}
// SourceInfo carries source identity context for rule evaluation.
// The source may be a direct WireGuard peer or a host behind
// a site-to-site gateway.
type SourceInfo struct {
// IP is the original source address from the packet.
IP netip.Addr
// PolicyID is the management policy that allowed this traffic
// through route ACLs.
PolicyID PolicyID
}
// ProtoType identifies a protocol handled by the proxy.
type ProtoType string
const (
ProtoHTTP ProtoType = "http"
ProtoHTTPS ProtoType = "https"
ProtoH2 ProtoType = "h2"
ProtoH3 ProtoType = "h3"
ProtoWebSocket ProtoType = "websocket"
ProtoOther ProtoType = "other"
)
// Rule defines a proxy inspection/filtering rule.
type Rule struct {
// ID uniquely identifies this rule.
ID id.RuleID
// Sources are the source CIDRs this rule applies to.
// Includes both direct peer IPs and routed networks behind gateways.
Sources []netip.Prefix
// Domains are the destination domain patterns to match (via SNI or Host header).
// Supports exact match ("example.com") and wildcard ("*.example.com").
Domains []domain.Domain
// Networks are the destination CIDRs to match.
Networks []netip.Prefix
// Ports are the destination ports to match. Empty means all ports.
Ports []uint16
// Protocols restricts which protocols this rule applies to.
// Empty means all protocols.
Protocols []ProtoType
// Paths are URL path patterns to match (HTTP only, requires inspect for HTTPS).
// Supports prefix ("/api/"), exact ("/login"), and wildcard ("/admin/*").
// Empty means all paths.
Paths []string
// Action determines what to do with matched connections.
Action Action
// Priority controls evaluation order. Lower values are evaluated first.
Priority int
}
// ICAPConfig holds ICAP service configuration.
type ICAPConfig struct {
// ReqModURL is the ICAP REQMOD service URL (e.g., icap://server:1344/reqmod).
ReqModURL *url.URL
// RespModURL is the ICAP RESPMOD service URL (e.g., icap://server:1344/respmod).
RespModURL *url.URL
// MaxConnections is the connection pool size. Zero uses a default.
MaxConnections int
}
// TLSConfig holds the MITM CA configuration for TLS inspection.
type TLSConfig struct {
// CA is the certificate authority used to sign dynamic certificates.
CA *x509.Certificate
// CAKey is the CA's private key.
CAKey crypto.PrivateKey
}
// Config holds the transparent proxy configuration.
type Config struct {
// Enabled controls whether the proxy is active.
Enabled bool
// Mode selects built-in or external proxy operation.
Mode ProxyMode
// ExternalURL is the upstream proxy URL for ModeExternal.
// Supports http:// and socks5:// schemes.
ExternalURL *url.URL
// DefaultAction applies when no rule matches a connection.
DefaultAction Action
// RedirectSources are the source CIDRs whose traffic should be intercepted.
// Admin decides: "activate for these users/subnets."
// Used for both kernel TPROXY rules and userspace forwarder source filtering.
RedirectSources []netip.Prefix
// RedirectPorts are the destination ports to intercept. Empty means all ports.
RedirectPorts []uint16
// Rules are the proxy inspection/filtering rules, evaluated in Priority order.
Rules []Rule
// ICAP holds ICAP service configuration. Nil disables ICAP.
ICAP *ICAPConfig
// TLS holds the MITM CA. Nil means no MITM capability (ActionInspect rules ignored).
TLS *TLSConfig
// Envoy configuration (ModeEnvoy only)
Envoy *EnvoyConfig
// ListenAddr is the TPROXY listen address for kernel mode.
// Zero value disables the TPROXY listener.
ListenAddr netip.AddrPort
// WGNetwork is the WireGuard overlay network prefix.
// The proxy blocks dialing destinations inside this network.
WGNetwork netip.Prefix
// LocalIPChecker reports whether an IP belongs to the routing peer.
// Used to prevent SSRF to local services. May be nil.
LocalIPChecker LocalIPChecker
}
// EnvoyConfig holds configuration for the envoy sidecar mode.
type EnvoyConfig struct {
// BinaryPath is the path to the envoy binary.
// Empty means search $PATH for "envoy".
BinaryPath string
// AdminPort is the port for envoy's admin API (health checks, stats).
// Zero means auto-assign.
AdminPort uint16
// Snippets are user-provided config fragments merged into the generated bootstrap.
Snippets *EnvoySnippets
}
// EnvoySnippets holds user-provided YAML fragments for envoy config customization.
// Only safe snippet types are allowed: filters (HTTP and network) and clusters
// needed as dependencies for filter services. Listeners and bootstrap overrides
// are not exposed since we manage the listener and bootstrap.
type EnvoySnippets struct {
// HTTPFilters is YAML injected into the HCM filter chain before the router filter.
// Used for ext_authz, rate limiting, Lua, Wasm, RBAC, JWT auth, etc.
HTTPFilters string
// NetworkFilters is YAML injected into the TLS filter chain before tcp_proxy.
// Used for network-level RBAC, rate limiting, ext_authz on raw TCP.
NetworkFilters string
// Clusters is YAML for additional upstream clusters referenced by filters.
// Needed when filters call external services (ext_authz backend, rate limit service).
Clusters string
}

View File

@@ -0,0 +1,93 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/shared/management/domain"
)
func TestMatchDomain(t *testing.T) {
tests := []struct {
name string
pattern string
target string
want bool
}{
{
name: "exact match",
pattern: "example.com",
target: "example.com",
want: true,
},
{
name: "exact no match",
pattern: "example.com",
target: "other.com",
want: false,
},
{
name: "wildcard matches subdomain",
pattern: "*.example.com",
target: "foo.example.com",
want: true,
},
{
name: "wildcard matches deep subdomain",
pattern: "*.example.com",
target: "a.b.c.example.com",
want: true,
},
{
name: "wildcard does not match base",
pattern: "*.example.com",
target: "example.com",
want: false,
},
{
name: "wildcard does not match unrelated",
pattern: "*.example.com",
target: "foo.other.com",
want: false,
},
{
name: "case insensitive exact match",
pattern: "Example.COM",
target: "example.com",
want: true,
},
{
name: "case insensitive wildcard match",
pattern: "*.Example.COM",
target: "FOO.example.com",
want: true,
},
{
name: "wildcard does not match partial suffix",
pattern: "*.example.com",
target: "notexample.com",
want: false,
},
{
name: "unicode domain punycode match",
pattern: "*.münchen.de",
target: "sub.xn--mnchen-3ya.de",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
pattern, err := domain.FromString(tt.pattern)
require.NoError(t, err)
target, err := domain.FromString(tt.target)
require.NoError(t, err)
got := MatchDomain(pattern, target)
assert.Equal(t, tt.want, got)
})
}
}

View File

@@ -0,0 +1,25 @@
package inspect
import (
"net"
"syscall"
)
// newOutboundDialer creates a net.Dialer that clears the socket fwmark.
// In kernel TPROXY mode, accepted connections inherit the TPROXY fwmark.
// Without clearing it, outbound connections from the proxy would match
// the ip rule (fwmark -> local loopback) and loop back to the proxy
// instead of reaching the real destination.
func newOutboundDialer() net.Dialer {
return net.Dialer{
Control: func(_, _ string, c syscall.RawConn) error {
var sockErr error
if err := c.Control(func(fd uintptr) {
sockErr = syscall.SetsockoptInt(int(fd), syscall.SOL_SOCKET, syscall.SO_MARK, 0)
}); err != nil {
return err
}
return sockErr
},
}
}

View File

@@ -0,0 +1,11 @@
//go:build !linux
package inspect
import "net"
// newOutboundDialer returns a plain dialer on non-Linux platforms.
// TPROXY is Linux-only, so no fwmark clearing is needed.
func newOutboundDialer() net.Dialer {
return net.Dialer{}
}

298
client/inspect/envoy.go Normal file
View File

@@ -0,0 +1,298 @@
package inspect
import (
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"os"
"os/exec"
"path/filepath"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
envoyStartTimeout = 15 * time.Second
envoyHealthInterval = 500 * time.Millisecond
envoyStopTimeout = 10 * time.Second
envoyDrainTime = 5
)
// envoyManager manages the lifecycle of an envoy sidecar process.
type envoyManager struct {
log *log.Entry
cmd *exec.Cmd
configPath string
listenPort uint16
adminPort uint16
cancel context.CancelFunc
blockPagePath string
mu sync.Mutex
running bool
}
// startEnvoy finds the envoy binary, generates config, and spawns the process.
// It blocks until envoy reports healthy or the timeout expires.
func startEnvoy(ctx context.Context, logger *log.Entry, config Config) (*envoyManager, error) {
envCfg := config.Envoy
if envCfg == nil {
return nil, fmt.Errorf("envoy config is nil")
}
binaryPath, err := findEnvoyBinary(envCfg.BinaryPath)
if err != nil {
return nil, fmt.Errorf("find envoy binary: %w", err)
}
// Pick admin port
adminPort := envCfg.AdminPort
if adminPort == 0 {
p, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free admin port: %w", err)
}
adminPort = p
}
// Pick listener port
listenPort, err := findFreePort()
if err != nil {
return nil, fmt.Errorf("find free listener port: %w", err)
}
// Use a private temp directory (0700) to prevent local attackers from
// replacing the config file between write and envoy read.
configDir, err := os.MkdirTemp("", "nb-envoy-*")
if err != nil {
return nil, fmt.Errorf("create envoy config directory: %w", err)
}
// Write the block page HTML for envoy's direct_response to reference.
blockPagePath := filepath.Join(configDir, "block.html")
blockHTML := fmt.Sprintf(blockPageHTML, "blocked domain", "this domain")
if err := os.WriteFile(blockPagePath, []byte(blockHTML), 0600); err != nil {
return nil, fmt.Errorf("write envoy block page: %w", err)
}
// Generate config with the block page path embedded.
bootstrap, err := generateBootstrap(config, listenPort, adminPort, blockPagePath)
if err != nil {
return nil, fmt.Errorf("generate envoy bootstrap: %w", err)
}
configPath := filepath.Join(configDir, "bootstrap.yaml")
if err := os.WriteFile(configPath, bootstrap, 0600); err != nil {
return nil, fmt.Errorf("write envoy config: %w", err)
}
ctx, cancel := context.WithCancel(ctx)
cmd := exec.CommandContext(ctx, binaryPath,
"-c", configPath,
"--drain-time-s", fmt.Sprintf("%d", envoyDrainTime),
)
// Pipe envoy output to our logger.
cmd.Stdout = &logWriter{entry: logger, level: log.DebugLevel}
cmd.Stderr = &logWriter{entry: logger, level: log.WarnLevel}
if err := cmd.Start(); err != nil {
cancel()
os.Remove(configPath)
return nil, fmt.Errorf("start envoy: %w", err)
}
mgr := &envoyManager{
log: logger,
cmd: cmd,
configPath: configPath,
listenPort: listenPort,
adminPort: adminPort,
blockPagePath: blockPagePath,
cancel: cancel,
running: true,
}
// Wait for envoy to become healthy.
if err := mgr.waitHealthy(ctx); err != nil {
mgr.Stop()
return nil, fmt.Errorf("wait for envoy readiness: %w", err)
}
logger.Infof("inspect: envoy started (pid=%d, listen=%d, admin=%d)", cmd.Process.Pid, listenPort, adminPort)
// Monitor process exit in background.
go mgr.monitor()
return mgr, nil
}
// ListenAddr returns the address envoy listens on for forwarded connections.
func (m *envoyManager) ListenAddr() netip.AddrPort {
return netip.AddrPortFrom(netip.AddrFrom4([4]byte{127, 0, 0, 1}), m.listenPort)
}
// AdminAddr returns the envoy admin API address.
func (m *envoyManager) AdminAddr() string {
return fmt.Sprintf("127.0.0.1:%d", m.adminPort)
}
// Reload writes a new config and sends SIGHUP to envoy.
func (m *envoyManager) Reload(config Config) error {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return fmt.Errorf("envoy is not running")
}
bootstrap, err := generateBootstrap(config, m.listenPort, m.adminPort, m.blockPagePath)
if err != nil {
return fmt.Errorf("generate envoy bootstrap: %w", err)
}
if err := os.WriteFile(m.configPath, bootstrap, 0600); err != nil {
return fmt.Errorf("write envoy config: %w", err)
}
if err := signalReload(m.cmd.Process); err != nil {
return fmt.Errorf("signal envoy reload: %w", err)
}
m.log.Debugf("inspect: envoy config reloaded")
return nil
}
// Healthy checks the envoy admin API /ready endpoint.
func (m *envoyManager) Healthy() bool {
resp, err := http.Get(fmt.Sprintf("http://%s/ready", m.AdminAddr()))
if err != nil {
return false
}
defer resp.Body.Close()
return resp.StatusCode == http.StatusOK
}
// Stop terminates the envoy process and cleans up.
func (m *envoyManager) Stop() {
m.mu.Lock()
defer m.mu.Unlock()
if !m.running {
return
}
m.running = false
m.cancel()
if m.cmd.Process != nil {
done := make(chan struct{})
go func() {
m.cmd.Wait()
close(done)
}()
select {
case <-done:
case <-time.After(envoyStopTimeout):
m.log.Warnf("inspect: envoy did not exit in %s, killing", envoyStopTimeout)
m.cmd.Process.Kill()
<-done
}
}
os.RemoveAll(filepath.Dir(m.configPath))
m.log.Infof("inspect: envoy stopped")
}
// waitHealthy polls the admin API until envoy is ready or timeout.
func (m *envoyManager) waitHealthy(ctx context.Context) error {
deadline := time.After(envoyStartTimeout)
ticker := time.NewTicker(envoyHealthInterval)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
return ctx.Err()
case <-deadline:
return fmt.Errorf("envoy not ready after %s", envoyStartTimeout)
case <-ticker.C:
if m.Healthy() {
return nil
}
}
}
}
// monitor watches for unexpected envoy exits.
func (m *envoyManager) monitor() {
err := m.cmd.Wait()
m.mu.Lock()
wasRunning := m.running
m.running = false
m.mu.Unlock()
if wasRunning {
m.log.Errorf("inspect: envoy exited unexpectedly: %v", err)
}
}
// findEnvoyBinary resolves the envoy binary path.
func findEnvoyBinary(configPath string) (string, error) {
if configPath != "" {
if _, err := os.Stat(configPath); err != nil {
return "", fmt.Errorf("envoy binary not found at %s: %w", configPath, err)
}
return configPath, nil
}
path, err := exec.LookPath("envoy")
if err != nil {
return "", fmt.Errorf("envoy not found in PATH: %w", err)
}
return path, nil
}
// findFreePort asks the OS for an available TCP port.
func findFreePort() (uint16, error) {
ln, err := net.Listen("tcp", "127.0.0.1:0")
if err != nil {
return 0, err
}
port := uint16(ln.Addr().(*net.TCPAddr).Port)
ln.Close()
return port, nil
}
// logWriter adapts log.Entry to io.Writer for piping process output.
type logWriter struct {
entry *log.Entry
level log.Level
}
func (w *logWriter) Write(p []byte) (int, error) {
msg := strings.TrimRight(string(p), "\n\r")
if msg == "" {
return len(p), nil
}
switch w.level {
case log.WarnLevel:
w.entry.Warn(msg)
default:
w.entry.Debug(msg)
}
return len(p), nil
}
// Ensure logWriter satisfies io.Writer.
var _ io.Writer = (*logWriter)(nil)

View File

@@ -0,0 +1,382 @@
package inspect
import (
"bytes"
"fmt"
"strings"
"text/template"
)
// envoyBootstrapTmpl generates the full envoy bootstrap with rule translation.
// TLS rules become per-SNI filter chains; HTTP rules become per-domain virtual hosts.
var envoyBootstrapTmpl = template.Must(template.New("bootstrap").Funcs(template.FuncMap{
"quote": func(s string) string { return fmt.Sprintf("%q", s) },
}).Parse(`node:
id: netbird-inspect
cluster: netbird
admin:
address:
socket_address:
address: 127.0.0.1
port_value: {{.AdminPort}}
static_resources:
listeners:
- name: inspect_listener
address:
socket_address:
address: 127.0.0.1
port_value: {{.ListenPort}}
listener_filters:
- name: envoy.filters.listener.proxy_protocol
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.proxy_protocol.v3.ProxyProtocol
- name: envoy.filters.listener.tls_inspector
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.listener.tls_inspector.v3.TlsInspector
filter_chains:
{{- /* TLS filter chains: per-SNI block/allow + default */ -}}
{{- range .TLSChains}}
- filter_chain_match:
transport_protocol: tls
{{- if .ServerNames}}
server_names:
{{- range .ServerNames}}
- {{quote .}}
{{- end}}
{{- end}}
filters:
{{$.NetworkFiltersSnippet}} - name: envoy.filters.network.tcp_proxy
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.tcp_proxy.v3.TcpProxy
stat_prefix: {{.StatPrefix}}
cluster: original_dst
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] tcp %DOWNSTREAM_REMOTE_ADDRESS% -> %UPSTREAM_HOST% %RESPONSE_FLAGS% %DURATION%ms\n"
{{- end}}
{{- /* Plain HTTP filter chain with per-domain virtual hosts */}}
- filters:
- name: envoy.filters.network.http_connection_manager
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.network.http_connection_manager.v3.HttpConnectionManager
stat_prefix: inspect_http
access_log:
- name: envoy.access_loggers.stderr
typed_config:
"@type": type.googleapis.com/envoy.extensions.access_loggers.stream.v3.StderrAccessLog
log_format:
text_format: "[%START_TIME%] http %DOWNSTREAM_REMOTE_ADDRESS% %REQ(:AUTHORITY)% %REQ(:METHOD)% %REQ(X-ENVOY-ORIGINAL-PATH?:PATH)% %RESPONSE_CODE% %RESPONSE_FLAGS% %DURATION%ms\n"
http_filters:
{{.HTTPFiltersSnippet}} - name: envoy.filters.http.router
typed_config:
"@type": type.googleapis.com/envoy.extensions.filters.http.router.v3.Router
route_config:
virtual_hosts:
{{- range .VirtualHosts}}
- name: {{.Name}}
domains: [{{.DomainsStr}}]
routes:
{{- range .Routes}}
- match:
prefix: "{{if .PathPrefix}}{{.PathPrefix}}{{else}}/{{end}}"
{{- if .Block}}
direct_response:
status: 403
body:
filename: "{{$.BlockPagePath}}"
{{- else}}
route:
cluster: original_dst
{{- end}}
{{- end}}
{{- end}}
clusters:
- name: original_dst
type: ORIGINAL_DST
lb_policy: CLUSTER_PROVIDED
connect_timeout: 10s
{{.ExtraClusters}}`))
// tlsChain represents a TLS filter chain entry for the template.
// All TLS chains are passthrough (block decisions happen in Go before envoy).
type tlsChain struct {
// ServerNames restricts this chain to specific SNIs. Empty is catch-all.
ServerNames []string
StatPrefix string
}
// envoyRoute represents a single route entry within a virtual host.
type envoyRoute struct {
// PathPrefix for envoy prefix match. Empty means catch-all "/".
PathPrefix string
Block bool
}
// virtualHost represents an HTTP virtual host entry for the template.
type virtualHost struct {
Name string
// DomainsStr is pre-formatted for the template: "a", "b".
DomainsStr string
Routes []envoyRoute
}
type bootstrapData struct {
AdminPort uint16
ListenPort uint16
BlockPagePath string
TLSChains []tlsChain
VirtualHosts []virtualHost
HTTPFiltersSnippet string
NetworkFiltersSnippet string
ExtraClusters string
}
// generateBootstrap produces the envoy bootstrap YAML from the inspect config.
// Translates inspection rules into envoy-native per-SNI and per-domain routing.
// blockPagePath is the path to the HTML block page file served by direct_response.
func generateBootstrap(config Config, listenPort, adminPort uint16, blockPagePath string) ([]byte, error) {
data := bootstrapData{
AdminPort: adminPort,
BlockPagePath: blockPagePath,
ListenPort: listenPort,
TLSChains: buildTLSChains(config),
VirtualHosts: buildVirtualHosts(config),
}
if config.Envoy != nil && config.Envoy.Snippets != nil {
s := config.Envoy.Snippets
data.HTTPFiltersSnippet = indentSnippet(s.HTTPFilters, 18)
data.NetworkFiltersSnippet = indentSnippet(s.NetworkFilters, 12)
data.ExtraClusters = indentSnippet(s.Clusters, 4)
}
var buf bytes.Buffer
if err := envoyBootstrapTmpl.Execute(&buf, data); err != nil {
return nil, fmt.Errorf("execute bootstrap template: %w", err)
}
return buf.Bytes(), nil
}
// buildTLSChains translates inspection rules into envoy TLS filter chains.
// Block rules -> per-SNI chain routing to blackhole.
// Allow rules (when default=block) -> per-SNI chain routing to original_dst.
// Default chain follows DefaultAction.
func buildTLSChains(config Config) []tlsChain {
// TLS block decisions happen in Go before forwarding to envoy, so we only
// generate allow/passthrough chains here. Envoy can't cleanly close a TLS
// connection without completing a handshake, so blocked SNIs never reach envoy.
var allowed []string
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTPS, ProtoH2) {
continue
}
for _, d := range rule.Domains {
sni := d.PunycodeString()
if rule.Action == ActionAllow || rule.Action == ActionInspect {
allowed = append(allowed, sni)
}
}
}
var chains []tlsChain
if len(allowed) > 0 && config.DefaultAction == ActionBlock {
chains = append(chains, tlsChain{
ServerNames: allowed,
StatPrefix: "tls_allowed",
})
}
// Default catch-all: passthrough (blocked SNIs never arrive here)
chains = append(chains, tlsChain{
StatPrefix: "tls_default",
})
return chains
}
// buildVirtualHosts translates inspection rules into envoy HTTP virtual hosts.
// Groups rules by domain, generates per-path routes within each virtual host.
func buildVirtualHosts(config Config) []virtualHost {
// Group rules by domain for per-domain virtual hosts.
type domainRules struct {
domains []string
routes []envoyRoute
}
domainRouteMap := make(map[string][]envoyRoute)
for _, rule := range config.Rules {
if !ruleTouchesProtocol(rule, ProtoHTTP, ProtoWebSocket) {
continue
}
isBlock := rule.Action == ActionBlock
// Rules without domains or paths are handled by the default action.
if len(rule.Domains) == 0 && len(rule.Paths) == 0 {
continue
}
// Build routes for this rule's paths
var routes []envoyRoute
if len(rule.Paths) > 0 {
for _, p := range rule.Paths {
// Convert our path patterns to envoy prefix match.
// Strip trailing * for envoy prefix matching.
prefix := strings.TrimSuffix(p, "*")
routes = append(routes, envoyRoute{PathPrefix: prefix, Block: isBlock})
}
} else {
routes = append(routes, envoyRoute{Block: isBlock})
}
if len(rule.Domains) > 0 {
for _, d := range rule.Domains {
host := d.PunycodeString()
domainRouteMap[host] = append(domainRouteMap[host], routes...)
}
} else {
// No domain: applies to all, add to default host
domainRouteMap["*"] = append(domainRouteMap["*"], routes...)
}
}
var hosts []virtualHost
idx := 0
// Per-domain virtual hosts with path routes
for domain, routes := range domainRouteMap {
if domain == "*" {
continue
}
// Add a catch-all route after path-specific routes.
// The catch-all follows the default action.
routes = append(routes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: fmt.Sprintf("domain_%d", idx),
DomainsStr: fmt.Sprintf("%q", domain),
Routes: routes,
})
idx++
}
// Default virtual host (catch-all for unmatched domains)
defaultRoutes := domainRouteMap["*"]
defaultRoutes = append(defaultRoutes, envoyRoute{Block: config.DefaultAction == ActionBlock})
hosts = append(hosts, virtualHost{
Name: "default",
DomainsStr: `"*"`,
Routes: defaultRoutes,
})
return hosts
}
// ruleTouchesProtocol returns true if the rule's protocol list includes any of the given protocols,
// or if the protocol list is empty (matches all).
func ruleTouchesProtocol(rule Rule, protos ...ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, rp := range rule.Protocols {
for _, p := range protos {
if rp == p {
return true
}
}
}
return false
}
// indentSnippet prepends each line of the YAML snippet with the given number of spaces.
// Returns empty string if snippet is empty.
func indentSnippet(snippet string, spaces int) string {
if snippet == "" {
return ""
}
prefix := make([]byte, spaces)
for i := range prefix {
prefix[i] = ' '
}
var buf bytes.Buffer
for i, line := range bytes.Split([]byte(snippet), []byte("\n")) {
if i > 0 {
buf.WriteByte('\n')
}
if len(line) > 0 {
buf.Write(prefix)
buf.Write(line)
}
}
buf.WriteByte('\n')
return buf.String()
}
// ValidateSnippets checks that user-provided snippets are safe to inject
// into the envoy config. Returns an error describing the first violation found.
//
// Validation rules:
// - Each snippet must be valid YAML (prevents syntax-level injection)
// - Snippets must not contain YAML document separators (--- or ...) that could
// break out of the indentation context
// - Snippets must only contain list items (starting with "- ") at the top level,
// matching what envoy expects for filters and clusters
func ValidateSnippets(snippets *EnvoySnippets) error {
if snippets == nil {
return nil
}
fields := []struct {
name string
value string
}{
{"http_filters", snippets.HTTPFilters},
{"network_filters", snippets.NetworkFilters},
{"clusters", snippets.Clusters},
}
for _, f := range fields {
if f.value == "" {
continue
}
if err := validateSnippetYAML(f.name, f.value); err != nil {
return err
}
}
return nil
}
func validateSnippetYAML(name, snippet string) error {
// Check for YAML document markers that could break template structure.
for _, line := range strings.Split(snippet, "\n") {
trimmed := strings.TrimSpace(line)
if trimmed == "---" || trimmed == "..." {
return fmt.Errorf("snippet %q: YAML document separators (--- or ...) are not allowed", name)
}
}
// Verify it's valid YAML by checking it doesn't cause template execution issues.
// We can't import yaml.v3 here without adding a dependency, so we do structural checks.
// Check for null bytes or control characters that could confuse YAML parsers.
for i, b := range []byte(snippet) {
if b == 0 {
return fmt.Errorf("snippet %q: null byte at position %d", name, i)
}
if b < 0x09 || (b > 0x0D && b < 0x20 && b != 0x1B) {
return fmt.Errorf("snippet %q: control character 0x%02x at position %d", name, b, i)
}
}
return nil
}

View File

@@ -0,0 +1,88 @@
package inspect
import (
"context"
"encoding/binary"
"fmt"
"net"
"net/netip"
)
// PROXY protocol v2 constants (RFC 7239 / HAProxy spec)
var proxyV2Signature = [12]byte{
0x0D, 0x0A, 0x0D, 0x0A, 0x00, 0x0D, 0x0A, 0x51,
0x55, 0x49, 0x54, 0x0A,
}
const (
proxyV2VersionCommand = 0x21 // version 2, PROXY command
proxyV2FamilyTCP4 = 0x11 // AF_INET, STREAM
proxyV2FamilyTCP6 = 0x21 // AF_INET6, STREAM
)
// forwardToEnvoy forwards a connection to the given envoy sidecar via PROXY protocol v2.
// The caller provides the envoy manager snapshot to avoid accessing p.envoy without lock.
func (p *Proxy) forwardToEnvoy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo, em *envoyManager) error {
envoyAddr := em.ListenAddr()
conn, err := (&net.Dialer{}).DialContext(ctx, "tcp", envoyAddr.String())
if err != nil {
return fmt.Errorf("dial envoy at %s: %w", envoyAddr, err)
}
defer func() {
if err := conn.Close(); err != nil {
p.log.Debugf("close envoy conn: %v", err)
}
}()
if err := writeProxyV2Header(conn, src.IP, dst); err != nil {
return fmt.Errorf("write PROXY v2 header: %w", err)
}
p.log.Tracef("envoy: forwarded %s -> %s via PROXY v2", src.IP, dst)
return relay(ctx, pconn, conn)
}
// writeProxyV2Header writes a PROXY protocol v2 header to w.
// The header encodes the original source IP and the destination address:port.
func writeProxyV2Header(w net.Conn, srcIP netip.Addr, dst netip.AddrPort) error {
srcIP = srcIP.Unmap()
dstIP := dst.Addr().Unmap()
var (
family byte
addrs []byte
)
if srcIP.Is4() && dstIP.Is4() {
family = proxyV2FamilyTCP4
s4 := srcIP.As4()
d4 := dstIP.As4()
addrs = make([]byte, 12) // 4+4+2+2
copy(addrs[0:4], s4[:])
copy(addrs[4:8], d4[:])
binary.BigEndian.PutUint16(addrs[8:10], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[10:12], dst.Port())
} else {
family = proxyV2FamilyTCP6
s16 := srcIP.As16()
d16 := dstIP.As16()
addrs = make([]byte, 36) // 16+16+2+2
copy(addrs[0:16], s16[:])
copy(addrs[16:32], d16[:])
binary.BigEndian.PutUint16(addrs[32:34], 0) // src port unknown
binary.BigEndian.PutUint16(addrs[34:36], dst.Port())
}
// Header: signature(12) + ver_cmd(1) + family(1) + len(2) + addrs
header := make([]byte, 16+len(addrs))
copy(header[0:12], proxyV2Signature[:])
header[12] = proxyV2VersionCommand
header[13] = family
binary.BigEndian.PutUint16(header[14:16], uint16(len(addrs)))
copy(header[16:], addrs)
_, err := w.Write(header)
return err
}

View File

@@ -0,0 +1,13 @@
//go:build !windows
package inspect
import (
"os"
"syscall"
)
// signalReload sends SIGHUP to the envoy process to trigger config reload.
func signalReload(p *os.Process) error {
return p.Signal(syscall.SIGHUP)
}

View File

@@ -0,0 +1,13 @@
//go:build windows
package inspect
import (
"fmt"
"os"
)
// signalReload is not supported on Windows. Envoy must be restarted.
func signalReload(_ *os.Process) error {
return fmt.Errorf("envoy config reload via signal not supported on Windows")
}

229
client/inspect/external.go Normal file
View File

@@ -0,0 +1,229 @@
package inspect
import (
"bufio"
"context"
"encoding/base64"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"net/url"
"time"
)
const (
externalDialTimeout = 10 * time.Second
)
// handleExternal forwards the connection to an external proxy.
// For TLS connections, it uses HTTP CONNECT to tunnel through the proxy.
// For HTTP connections, it rewrites the request to use the proxy.
func (p *Proxy) handleExternal(ctx context.Context, pconn *peekConn, dst netip.AddrPort) error {
p.mu.RLock()
proxyURL := p.config.ExternalURL
p.mu.RUnlock()
if proxyURL == nil {
return fmt.Errorf("external proxy URL not configured")
}
switch proxyURL.Scheme {
case "http", "https":
return p.externalHTTPProxy(ctx, pconn, dst, proxyURL)
case "socks5":
return p.externalSOCKS5(ctx, pconn, dst, proxyURL)
default:
return fmt.Errorf("unsupported external proxy scheme: %s", proxyURL.Scheme)
}
}
// externalHTTPProxy tunnels through an HTTP proxy using CONNECT.
func (p *Proxy) externalHTTPProxy(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "8080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial external proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close external proxy conn: %v", err)
}
}()
connectReq := fmt.Sprintf("CONNECT %s HTTP/1.1\r\nHost: %s\r\n", dst.String(), dst.String())
if proxyURL.User != nil {
connectReq += "Proxy-Authorization: Basic " + basicAuth(proxyURL.User) + "\r\n"
}
connectReq += "\r\n"
if _, err := io.WriteString(proxyConn, connectReq); err != nil {
return fmt.Errorf("send CONNECT to proxy: %w", err)
}
resp, err := http.ReadResponse(bufio.NewReader(proxyConn), nil)
if err != nil {
return fmt.Errorf("read CONNECT response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close CONNECT resp body: %v", err)
}
if resp.StatusCode != http.StatusOK {
return fmt.Errorf("proxy CONNECT failed: %s", resp.Status)
}
return relay(ctx, pconn, proxyConn)
}
// externalSOCKS5 tunnels through a SOCKS5 proxy.
func (p *Proxy) externalSOCKS5(ctx context.Context, pconn *peekConn, dst netip.AddrPort, proxyURL *url.URL) error {
proxyAddr := proxyURL.Host
if _, _, err := net.SplitHostPort(proxyAddr); err != nil {
proxyAddr = net.JoinHostPort(proxyAddr, "1080")
}
proxyConn, err := (&net.Dialer{Timeout: externalDialTimeout}).DialContext(ctx, "tcp", proxyAddr)
if err != nil {
return fmt.Errorf("dial SOCKS5 proxy %s: %w", proxyAddr, err)
}
defer func() {
if err := proxyConn.Close(); err != nil {
p.log.Debugf("close SOCKS5 proxy conn: %v", err)
}
}()
if err := socks5Handshake(proxyConn, dst, proxyURL.User); err != nil {
return fmt.Errorf("SOCKS5 handshake: %w", err)
}
return relay(ctx, pconn, proxyConn)
}
// socks5Handshake performs the SOCKS5 handshake to connect through the proxy.
func socks5Handshake(conn net.Conn, dst netip.AddrPort, userinfo *url.Userinfo) error {
needAuth := userinfo != nil
// Greeting
var methods []byte
if needAuth {
methods = []byte{0x00, 0x02} // no auth, username/password
} else {
methods = []byte{0x00} // no auth
}
greeting := append([]byte{0x05, byte(len(methods))}, methods...)
if _, err := conn.Write(greeting); err != nil {
return fmt.Errorf("send greeting: %w", err)
}
// Server method selection
var methodResp [2]byte
if _, err := io.ReadFull(conn, methodResp[:]); err != nil {
return fmt.Errorf("read method selection: %w", err)
}
if methodResp[0] != 0x05 {
return fmt.Errorf("unexpected SOCKS version: %d", methodResp[0])
}
// Handle authentication if selected
if methodResp[1] == 0x02 {
if err := socks5Auth(conn, userinfo); err != nil {
return err
}
} else if methodResp[1] != 0x00 {
return fmt.Errorf("unsupported SOCKS5 auth method: %d", methodResp[1])
}
// Connection request
addr := dst.Addr()
var addrBytes []byte
if addr.Is4() {
a4 := addr.As4()
addrBytes = append([]byte{0x01}, a4[:]...) // IPv4
} else {
a16 := addr.As16()
addrBytes = append([]byte{0x04}, a16[:]...) // IPv6
}
port := dst.Port()
connectReq := append([]byte{0x05, 0x01, 0x00}, addrBytes...)
connectReq = append(connectReq, byte(port>>8), byte(port))
if _, err := conn.Write(connectReq); err != nil {
return fmt.Errorf("send connect request: %w", err)
}
// Read response (minimum 10 bytes for IPv4)
var respHeader [4]byte
if _, err := io.ReadFull(conn, respHeader[:]); err != nil {
return fmt.Errorf("read connect response: %w", err)
}
if respHeader[1] != 0x00 {
return fmt.Errorf("SOCKS5 connect failed: status %d", respHeader[1])
}
// Skip bound address
switch respHeader[3] {
case 0x01: // IPv4
var skip [4 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv4 address: %w", err)
}
case 0x04: // IPv6
var skip [16 + 2]byte
if _, err := io.ReadFull(conn, skip[:]); err != nil {
return fmt.Errorf("read SOCKS5 bound IPv6 address: %w", err)
}
case 0x03: // Domain
var dLen [1]byte
if _, err := io.ReadFull(conn, dLen[:]); err != nil {
return fmt.Errorf("read domain length: %w", err)
}
skip := make([]byte, int(dLen[0])+2)
if _, err := io.ReadFull(conn, skip); err != nil {
return fmt.Errorf("read SOCKS5 bound domain address: %w", err)
}
}
return nil
}
func socks5Auth(conn net.Conn, userinfo *url.Userinfo) error {
if userinfo == nil {
return fmt.Errorf("SOCKS5 auth required but no credentials provided")
}
user := userinfo.Username()
pass, _ := userinfo.Password()
// Username/password auth (RFC 1929)
auth := []byte{0x01, byte(len(user))}
auth = append(auth, []byte(user)...)
auth = append(auth, byte(len(pass)))
auth = append(auth, []byte(pass)...)
if _, err := conn.Write(auth); err != nil {
return fmt.Errorf("send auth: %w", err)
}
var resp [2]byte
if _, err := io.ReadFull(conn, resp[:]); err != nil {
return fmt.Errorf("read auth response: %w", err)
}
if resp[1] != 0x00 {
return fmt.Errorf("SOCKS5 auth failed: status %d", resp[1])
}
return nil
}
func basicAuth(userinfo *url.Userinfo) string {
user := userinfo.Username()
pass, _ := userinfo.Password()
return base64.StdEncoding.EncodeToString([]byte(user + ":" + pass))
}

532
client/inspect/http.go Normal file
View File

@@ -0,0 +1,532 @@
package inspect
import (
"bufio"
"context"
"fmt"
"io"
"net"
"net/http"
"net/netip"
"strings"
"sync"
"time"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
headerUpgrade = "Upgrade"
valueWebSocket = "websocket"
)
// inspectHTTP runs the HTTP inspection pipeline on decrypted traffic.
// It handles HTTP/1.1 (request-response loop), HTTP/2 (via Go stdlib reverse proxy),
// and WebSocket upgrade detection.
func (p *Proxy) inspectHTTP(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, proto string) error {
if proto == "h2" {
return p.inspectH2(ctx, client, remote, dst, sni, src)
}
return p.inspectH1(ctx, client, remote, dst, sni, src)
}
// inspectH1 handles HTTP/1.1 request-response inspection in a loop.
func (p *Proxy) inspectH1(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
clientReader := bufio.NewReader(client)
remoteReader := bufio.NewReader(remote)
for {
if ctx.Err() != nil {
return ctx.Err()
}
// Set idle timeout between requests to prevent connection hogging.
if err := client.SetReadDeadline(time.Now().Add(idleTimeout)); err != nil {
return fmt.Errorf("set idle deadline: %w", err)
}
req, err := http.ReadRequest(clientReader)
if err != nil {
if isClosedErr(err) {
return nil
}
return fmt.Errorf("read HTTP request: %w", err)
}
if err := client.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
// Re-evaluate rules based on Host header if SNI was empty
host := hostFromRequest(req, sni)
// Domain fronting: Host header doesn't match TLS SNI
if isDomainFronting(req, sni) {
p.log.Debugf("domain fronting detected: SNI=%s Host=%s", sni.PunycodeString(), host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
proto := ProtoHTTP
if isWebSocketUpgrade(req) {
proto = ProtoWebSocket
}
action := p.evaluateAction(src.IP, host, dst, proto, req.URL.Path)
if action == ActionBlock {
p.log.Debugf("block: HTTP %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockResponse(client, req, host)
return ErrBlocked
}
p.log.Tracef("allow: HTTP %s %s (host=%s, action=%s)", req.Method, req.URL.Path, host.PunycodeString(), action)
// ICAP REQMOD: send request for inspection.
// Snapshot ICAP client under lock to avoid use-after-close races.
p.mu.RLock()
icap := p.icap
p.mu.RUnlock()
if icap != nil {
modified, err := icap.ReqMod(req)
if err != nil {
p.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
// Fail-closed: block on ICAP error
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP REQMOD: %w", err)
}
req = modified
}
if isWebSocketUpgrade(req) {
return p.handleWebSocket(ctx, req, client, clientReader, remote, remoteReader)
}
removeHopByHopHeaders(req.Header)
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward request: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read HTTP response: %w", err)
}
// ICAP RESPMOD: send response for inspection
if icap != nil {
modified, err := icap.RespMod(req, resp)
if err != nil {
p.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
writeBlockResponse(client, req, host)
return fmt.Errorf("ICAP RESPMOD: %w", err)
}
resp = modified
}
removeHopByHopHeaders(resp.Header)
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close resp body: %v", closeErr)
}
return fmt.Errorf("forward response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close resp body: %v", err)
}
// Connection: close means we're done
if resp.Close || req.Close {
return nil
}
}
}
// inspectH2 proxies HTTP/2 traffic using Go's http stack.
// Client and remote are already-established TLS connections with h2 negotiated.
func (p *Proxy) inspectH2(ctx context.Context, client, remote net.Conn, dst netip.AddrPort, sni domain.Domain, src SourceInfo) error {
// For h2 MITM inspection, we use a local http.Server reading from the client
// connection and an http.Transport writing to the remote connection.
//
// The transport is configured to use the existing TLS connection to the
// real server. The handler inspects each request/response pair.
transport := &http.Transport{
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
DialTLSContext: func(_ context.Context, _, _ string) (net.Conn, error) {
return remote, nil
},
ForceAttemptHTTP2: true,
}
handler := &h2InspectionHandler{
proxy: p,
transport: transport,
dst: dst,
sni: sni,
src: src,
}
server := &http.Server{
Handler: handler,
}
// Serve the single client connection.
// ServeConn blocks until the connection is done.
errCh := make(chan error, 1)
go func() {
// http.Server doesn't have a direct ServeConn for h2,
// so we use Serve with a single-connection listener.
ln := &singleConnListener{conn: client}
errCh <- server.Serve(ln)
}()
select {
case <-ctx.Done():
if err := server.Close(); err != nil {
p.log.Debugf("close h2 server: %v", err)
}
return ctx.Err()
case err := <-errCh:
if err == http.ErrServerClosed {
return nil
}
return err
}
}
// h2InspectionHandler inspects each HTTP/2 request/response pair.
type h2InspectionHandler struct {
proxy *Proxy
transport http.RoundTripper
dst netip.AddrPort
sni domain.Domain
src SourceInfo
}
func (h *h2InspectionHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
host := hostFromRequest(req, h.sni)
if isDomainFronting(req, h.sni) {
h.proxy.log.Debugf("domain fronting detected: SNI=%s Host=%s", h.sni.PunycodeString(), host.PunycodeString())
writeBlockPage(w, host)
return
}
action := h.proxy.evaluateAction(h.src.IP, host, h.dst, ProtoH2, req.URL.Path)
if action == ActionBlock {
h.proxy.log.Debugf("block: H2 %s %s (host=%s)", req.Method, req.URL.Path, host.PunycodeString())
writeBlockPage(w, host)
return
}
// ICAP REQMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.ReqMod(req)
if err != nil {
h.proxy.log.Debugf("ICAP REQMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
req = modified
}
// Forward to upstream
req.URL.Scheme = "https"
req.URL.Host = h.sni.PunycodeString()
req.RequestURI = ""
resp, err := h.transport.RoundTrip(req)
if err != nil {
h.proxy.log.Debugf("h2 upstream error for %s: %v", host.PunycodeString(), err)
http.Error(w, "Bad Gateway", http.StatusBadGateway)
return
}
defer func() {
if err := resp.Body.Close(); err != nil {
h.proxy.log.Debugf("close h2 resp body: %v", err)
}
}()
// ICAP RESPMOD
if h.proxy.icap != nil {
modified, err := h.proxy.icap.RespMod(req, resp)
if err != nil {
h.proxy.log.Debugf("ICAP RESPMOD error for %s: %v", host.PunycodeString(), err)
writeBlockPage(w, host)
return
}
resp = modified
}
// Copy response headers and body
for k, vals := range resp.Header {
for _, v := range vals {
w.Header().Add(k, v)
}
}
w.WriteHeader(resp.StatusCode)
if _, err := io.Copy(w, resp.Body); err != nil {
h.proxy.log.Debugf("h2 response copy error: %v", err)
}
}
// handleWebSocket completes the WebSocket upgrade and relays frames bidirectionally.
func (p *Proxy) handleWebSocket(ctx context.Context, req *http.Request, client io.ReadWriter, clientReader *bufio.Reader, remote io.ReadWriter, remoteReader *bufio.Reader) error {
if err := req.Write(remote); err != nil {
return fmt.Errorf("forward WebSocket upgrade: %w", err)
}
resp, err := http.ReadResponse(remoteReader, req)
if err != nil {
return fmt.Errorf("read WebSocket upgrade response: %w", err)
}
if err := resp.Write(client); err != nil {
if closeErr := resp.Body.Close(); closeErr != nil {
p.log.Debugf("close ws resp body: %v", closeErr)
}
return fmt.Errorf("forward WebSocket upgrade response: %w", err)
}
if err := resp.Body.Close(); err != nil {
p.log.Debugf("close ws resp body: %v", err)
}
if resp.StatusCode != http.StatusSwitchingProtocols {
return fmt.Errorf("WebSocket upgrade rejected: status %d", resp.StatusCode)
}
p.log.Tracef("allow: WebSocket upgrade for %s", req.Host)
// Relay WebSocket frames bidirectionally.
// clientReader/remoteReader may have buffered data.
clientConn := mergeReadWriter(clientReader, client)
remoteConn := mergeReadWriter(remoteReader, remote)
return relayRW(ctx, clientConn, remoteConn)
}
// hostFromRequest extracts a domain.Domain from the HTTP request Host header,
// falling back to the SNI if Host is empty or an IP.
func hostFromRequest(req *http.Request, fallback domain.Domain) domain.Domain {
host := req.Host
if host == "" {
return fallback
}
// Strip port if present
if h, _, err := net.SplitHostPort(host); err == nil {
host = h
}
// If it's an IP address, use the SNI fallback
if _, err := netip.ParseAddr(host); err == nil {
return fallback
}
d, err := domain.FromString(host)
if err != nil {
return fallback
}
return d
}
// isDomainFronting detects domain fronting: the Host header doesn't match the
// SNI used during the TLS handshake. Only meaningful when SNI is non-empty
// (i.e., we're in MITM mode and know the original SNI).
func isDomainFronting(req *http.Request, sni domain.Domain) bool {
if sni == "" {
return false
}
host := hostFromRequest(req, "")
if host == "" {
return false
}
// Host should match SNI or be a subdomain of SNI
if host == sni {
return false
}
// Allow www.example.com when SNI is example.com
sniStr := sni.PunycodeString()
hostStr := host.PunycodeString()
if strings.HasSuffix(hostStr, "."+sniStr) {
return false
}
return true
}
func isWebSocketUpgrade(req *http.Request) bool {
return strings.EqualFold(req.Header.Get(headerUpgrade), valueWebSocket)
}
// writeBlockPage writes the styled HTML block page to an http.ResponseWriter (H2 path).
func writeBlockPage(w http.ResponseWriter, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
w.Header().Set("Content-Type", "text/html; charset=utf-8")
w.Header().Set("Cache-Control", "no-store")
w.WriteHeader(http.StatusForbidden)
io.WriteString(w, body)
}
func writeBlockResponse(w io.Writer, _ *http.Request, host domain.Domain) {
hostname := host.PunycodeString()
body := fmt.Sprintf(blockPageHTML, hostname, hostname)
resp := &http.Response{
StatusCode: http.StatusForbidden,
ProtoMajor: 1,
ProtoMinor: 1,
Header: make(http.Header),
ContentLength: int64(len(body)),
Body: io.NopCloser(strings.NewReader(body)),
}
resp.Header.Set("Content-Type", "text/html; charset=utf-8")
resp.Header.Set("Connection", "close")
resp.Header.Set("Cache-Control", "no-store")
_ = resp.Write(w)
}
// blockPageHTML is the self-contained HTML block page.
// Uses NetBird dark theme with orange accent. Two format args: page title domain, displayed domain.
const blockPageHTML = `<!DOCTYPE html>
<html lang="en">
<head>
<meta charset="utf-8">
<meta name="viewport" content="width=device-width,initial-scale=1">
<title>Blocked - %s</title>
<style>
*{margin:0;padding:0;box-sizing:border-box}
body{background:#181a1d;color:#d1d5db;font-family:-apple-system,BlinkMacSystemFont,"Segoe UI",Roboto,sans-serif;min-height:100vh;display:flex;align-items:center;justify-content:center}
.c{text-align:center;max-width:460px;padding:2rem}
.shield{width:56px;height:56px;margin:0 auto 1.5rem;border-radius:16px;background:#2b2f33;display:flex;align-items:center;justify-content:center}
.shield svg{width:28px;height:28px;color:#f68330}
.code{font-size:.8rem;font-weight:500;color:#f68330;font-family:ui-monospace,monospace;letter-spacing:.05em;margin-bottom:.5rem}
h1{font-size:1.5rem;font-weight:600;color:#f4f4f5;margin-bottom:.5rem}
p{font-size:.95rem;line-height:1.5;color:#9ca3af;margin-bottom:1.75rem}
.domain{display:inline-block;background:#25282d;border:1px solid #32363d;border-radius:6px;padding:.15rem .5rem;font-family:ui-monospace,monospace;font-size:.85rem;color:#d1d5db}
.footer{font-size:.7rem;color:#6b7280;margin-top:2rem;letter-spacing:.03em}
.footer a{color:#6b7280;text-decoration:none}
.footer a:hover{color:#9ca3af}
</style>
</head>
<body>
<div class="c">
<div class="shield"><svg xmlns="http://www.w3.org/2000/svg" fill="none" viewBox="0 0 24 24" stroke-width="1.5" stroke="currentColor"><path stroke-linecap="round" stroke-linejoin="round" d="M12 9v3.75m0-10.036A11.959 11.959 0 0 1 3.598 6 11.99 11.99 0 0 0 3 9.75c0 5.592 3.824 10.29 9 11.622 5.176-1.332 9-6.03 9-11.622 0-1.31-.21-2.571-.598-3.751A11.96 11.96 0 0 0 12 3.714Z"/></svg></div>
<div class="code">403 BLOCKED</div>
<h1>Access Denied</h1>
<p>This connection to <span class="domain">%s</span> has been blocked by your organization's network policy.</p>
<div class="footer">Protected by <a href="https://netbird.io" target="_blank" rel="noopener">NetBird</a></div>
</div>
</body>
</html>`
// singleConnListener is a net.Listener that yields a single connection.
type singleConnListener struct {
conn net.Conn
once sync.Once
ch chan struct{}
}
func (l *singleConnListener) Accept() (net.Conn, error) {
var accepted bool
l.once.Do(func() {
l.ch = make(chan struct{})
accepted = true
})
if accepted {
return l.conn, nil
}
// Block until Close
<-l.ch
return nil, net.ErrClosed
}
func (l *singleConnListener) Close() error {
l.once.Do(func() {
l.ch = make(chan struct{})
})
select {
case <-l.ch:
default:
close(l.ch)
}
return nil
}
func (l *singleConnListener) Addr() net.Addr {
return l.conn.LocalAddr()
}
type readWriter struct {
io.Reader
io.Writer
}
func mergeReadWriter(r io.Reader, w io.Writer) io.ReadWriter {
return &readWriter{Reader: r, Writer: w}
}
// relayRW copies data bidirectionally between two ReadWriters.
func relayRW(ctx context.Context, a, b io.ReadWriter) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(b, a)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(a, b)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// hopByHopHeaders are HTTP/1.1 headers that apply to a single connection
// and must not be forwarded by a proxy (RFC 7230, Section 6.1).
var hopByHopHeaders = []string{
"Connection",
"Keep-Alive",
"Proxy-Authenticate",
"Proxy-Authorization",
"TE",
"Trailers",
"Transfer-Encoding",
"Upgrade",
}
// removeHopByHopHeaders strips hop-by-hop headers from h.
// Also removes headers listed in the Connection header value.
func removeHopByHopHeaders(h http.Header) {
// First, remove any headers named in the Connection header
for _, connHeader := range h["Connection"] {
for _, name := range strings.Split(connHeader, ",") {
h.Del(strings.TrimSpace(name))
}
}
for _, name := range hopByHopHeaders {
h.Del(name)
}
}

479
client/inspect/icap.go Normal file
View File

@@ -0,0 +1,479 @@
package inspect
import (
"bufio"
"bytes"
"fmt"
"io"
"net"
"net/http"
"net/textproto"
"net/url"
"strconv"
"strings"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
const (
icapVersion = "ICAP/1.0"
icapDefaultPort = "1344"
icapConnTimeout = 30 * time.Second
icapRWTimeout = 60 * time.Second
icapMaxPoolSize = 8
icapIdleTimeout = 60 * time.Second
icapMaxRespSize = 4 * 1024 * 1024 // 4 MB
)
// ICAPClient implements an ICAP (RFC 3507) client with persistent connection pooling.
type ICAPClient struct {
reqModURL *url.URL
respModURL *url.URL
pool chan *icapConn
mu sync.Mutex
log *log.Entry
maxPool int
}
type icapConn struct {
conn net.Conn
reader *bufio.Reader
lastUse time.Time
}
// NewICAPClient creates an ICAP client. Either or both URLs may be nil
// to disable that mode.
func NewICAPClient(logger *log.Entry, cfg *ICAPConfig) *ICAPClient {
maxPool := cfg.MaxConnections
if maxPool <= 0 {
maxPool = icapMaxPoolSize
}
return &ICAPClient{
reqModURL: cfg.ReqModURL,
respModURL: cfg.RespModURL,
pool: make(chan *icapConn, maxPool),
log: logger,
maxPool: maxPool,
}
}
// ReqMod sends an HTTP request to the ICAP REQMOD service for inspection.
// Returns the (possibly modified) request, or the original if ICAP returns 204.
// Returns nil, nil if REQMOD is not configured.
func (c *ICAPClient) ReqMod(req *http.Request) (*http.Request, error) {
if c.reqModURL == nil {
return req, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
respBody, err := c.send("REQMOD", c.reqModURL, reqBuf.Bytes(), nil)
if err != nil {
return nil, err
}
if respBody == nil {
return req, nil
}
modified, err := http.ReadRequest(bufio.NewReader(bytes.NewReader(respBody)))
if err != nil {
return nil, fmt.Errorf("parse ICAP modified request: %w", err)
}
return modified, nil
}
// RespMod sends an HTTP response to the ICAP RESPMOD service for inspection.
// Returns the (possibly modified) response, or the original if ICAP returns 204.
// Returns nil, nil if RESPMOD is not configured.
func (c *ICAPClient) RespMod(req *http.Request, resp *http.Response) (*http.Response, error) {
if c.respModURL == nil {
return resp, nil
}
var reqBuf bytes.Buffer
if err := req.Write(&reqBuf); err != nil {
return nil, fmt.Errorf("serialize request: %w", err)
}
var respBuf bytes.Buffer
if err := resp.Write(&respBuf); err != nil {
return nil, fmt.Errorf("serialize response: %w", err)
}
respBody, err := c.send("RESPMOD", c.respModURL, reqBuf.Bytes(), respBuf.Bytes())
if err != nil {
return nil, err
}
if respBody == nil {
// 204 No Content: ICAP server didn't modify the response.
// Reconstruct from the buffered copy since resp.Body was consumed by Write.
reconstructed, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBuf.Bytes())), req)
if err != nil {
return nil, fmt.Errorf("reconstruct response after ICAP 204: %w", err)
}
return reconstructed, nil
}
modified, err := http.ReadResponse(bufio.NewReader(bytes.NewReader(respBody)), req)
if err != nil {
return nil, fmt.Errorf("parse ICAP modified response: %w", err)
}
return modified, nil
}
// Close drains and closes all pooled connections.
func (c *ICAPClient) Close() {
close(c.pool)
for ic := range c.pool {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close ICAP connection: %v", err)
}
}
}
// send executes an ICAP request and returns the encapsulated body from the response.
// Returns nil body for 204 No Content (no modification).
// Retries once on stale pooled connection (EOF on read).
func (c *ICAPClient) send(method string, serviceURL *url.URL, reqData, respData []byte) ([]byte, error) {
statusCode, headers, body, err := c.trySend(method, serviceURL, reqData, respData)
if err != nil && isStaleConnErr(err) {
// Retry once with a fresh connection (stale pool entry).
c.log.Debugf("ICAP %s: retrying after stale connection: %v", method, err)
statusCode, headers, body, err = c.trySend(method, serviceURL, reqData, respData)
}
if err != nil {
return nil, err
}
switch statusCode {
case 204:
return nil, nil
case 200:
return body, nil
default:
c.log.Debugf("ICAP %s returned status %d, headers: %v", method, statusCode, headers)
return nil, fmt.Errorf("ICAP %s: status %d", method, statusCode)
}
}
func (c *ICAPClient) trySend(method string, serviceURL *url.URL, reqData, respData []byte) (int, textproto.MIMEHeader, []byte, error) {
ic, err := c.getConn(serviceURL)
if err != nil {
return 0, nil, nil, fmt.Errorf("get ICAP connection: %w", err)
}
if err := c.writeRequest(ic, method, serviceURL, reqData, respData); err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after write error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("write ICAP %s: %w", method, err)
}
statusCode, headers, body, err := c.readResponse(ic)
if err != nil {
if closeErr := ic.conn.Close(); closeErr != nil {
c.log.Debugf("close ICAP conn after read error: %v", closeErr)
}
return 0, nil, nil, fmt.Errorf("read ICAP response: %w", err)
}
c.putConn(ic)
return statusCode, headers, body, nil
}
func isStaleConnErr(err error) bool {
if err == nil {
return false
}
s := err.Error()
return strings.Contains(s, "EOF") || strings.Contains(s, "broken pipe") || strings.Contains(s, "connection reset")
}
func (c *ICAPClient) writeRequest(ic *icapConn, method string, serviceURL *url.URL, reqData, respData []byte) error {
if err := ic.conn.SetWriteDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return fmt.Errorf("set write deadline: %w", err)
}
// For RESPMOD, split the serialized HTTP response into headers and body.
// The body must be sent chunked per RFC 3507.
var respHdr, respBody []byte
if respData != nil {
if idx := bytes.Index(respData, []byte("\r\n\r\n")); idx >= 0 {
respHdr = respData[:idx+4] // include the \r\n\r\n separator
respBody = respData[idx+4:]
} else {
respHdr = respData
}
}
var buf bytes.Buffer
// Request line
fmt.Fprintf(&buf, "%s %s %s\r\n", method, serviceURL.String(), icapVersion)
// Headers
host := serviceURL.Host
fmt.Fprintf(&buf, "Host: %s\r\n", host)
fmt.Fprintf(&buf, "Connection: keep-alive\r\n")
fmt.Fprintf(&buf, "Allow: 204\r\n")
// Build Encapsulated header
offset := 0
var encapParts []string
if reqData != nil {
encapParts = append(encapParts, fmt.Sprintf("req-hdr=%d", offset))
offset += len(reqData)
}
if respHdr != nil {
encapParts = append(encapParts, fmt.Sprintf("res-hdr=%d", offset))
offset += len(respHdr)
}
if len(respBody) > 0 {
encapParts = append(encapParts, fmt.Sprintf("res-body=%d", offset))
} else {
encapParts = append(encapParts, fmt.Sprintf("null-body=%d", offset))
}
fmt.Fprintf(&buf, "Encapsulated: %s\r\n", strings.Join(encapParts, ", "))
fmt.Fprintf(&buf, "\r\n")
// Encapsulated sections
if reqData != nil {
buf.Write(reqData)
}
if respHdr != nil {
buf.Write(respHdr)
}
// Body in chunked encoding (only when there is an actual body section).
// Per RFC 3507 Section 4.4.1, null-body must not include any entity data.
if len(respBody) > 0 {
fmt.Fprintf(&buf, "%x\r\n", len(respBody))
buf.Write(respBody)
buf.WriteString("\r\n")
buf.WriteString("0\r\n\r\n")
}
_, err := ic.conn.Write(buf.Bytes())
return err
}
func (c *ICAPClient) readResponse(ic *icapConn) (int, textproto.MIMEHeader, []byte, error) {
if err := ic.conn.SetReadDeadline(time.Now().Add(icapRWTimeout)); err != nil {
return 0, nil, nil, fmt.Errorf("set read deadline: %w", err)
}
tp := textproto.NewReader(ic.reader)
// Status line: "ICAP/1.0 200 OK"
statusLine, err := tp.ReadLine()
if err != nil {
return 0, nil, nil, fmt.Errorf("read status line: %w", err)
}
statusCode, err := parseICAPStatus(statusLine)
if err != nil {
return 0, nil, nil, err
}
// Headers
headers, err := tp.ReadMIMEHeader()
if err != nil {
return statusCode, nil, nil, fmt.Errorf("read ICAP headers: %w", err)
}
if statusCode == 204 {
return statusCode, headers, nil, nil
}
// Read encapsulated body based on Encapsulated header
body, err := c.readEncapsulatedBody(ic.reader, headers)
if err != nil {
return statusCode, headers, nil, fmt.Errorf("read encapsulated body: %w", err)
}
return statusCode, headers, body, nil
}
func (c *ICAPClient) readEncapsulatedBody(r *bufio.Reader, headers textproto.MIMEHeader) ([]byte, error) {
encap := headers.Get("Encapsulated")
if encap == "" {
return nil, nil
}
// Find the body offset from the Encapsulated header.
// The last section with a non-zero offset is the body.
// Read everything from the reader as the encapsulated content.
var totalSize int
parts := strings.Split(encap, ",")
for _, part := range parts {
part = strings.TrimSpace(part)
eqIdx := strings.Index(part, "=")
if eqIdx < 0 {
continue
}
offset, err := strconv.Atoi(strings.TrimSpace(part[eqIdx+1:]))
if err != nil {
continue
}
if offset > totalSize {
totalSize = offset
}
}
// Read all available encapsulated data (headers + body)
// The body section uses chunked encoding per RFC 3507
var buf bytes.Buffer
if totalSize > 0 {
// Read the header sections (everything before the body offset)
headerBytes := make([]byte, totalSize)
if _, err := io.ReadFull(r, headerBytes); err != nil {
return nil, fmt.Errorf("read encapsulated headers: %w", err)
}
buf.Write(headerBytes)
}
// Read chunked body
chunked := newChunkedReader(r)
body, err := io.ReadAll(io.LimitReader(chunked, icapMaxRespSize))
if err != nil {
return nil, fmt.Errorf("read chunked body: %w", err)
}
buf.Write(body)
return buf.Bytes(), nil
}
func (c *ICAPClient) getConn(serviceURL *url.URL) (*icapConn, error) {
// Try to get a pooled connection
for {
select {
case ic := <-c.pool:
if time.Since(ic.lastUse) > icapIdleTimeout {
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close idle ICAP connection: %v", err)
}
continue
}
return ic, nil
default:
return c.dialConn(serviceURL)
}
}
}
func (c *ICAPClient) putConn(ic *icapConn) {
c.mu.Lock()
defer c.mu.Unlock()
ic.lastUse = time.Now()
select {
case c.pool <- ic:
default:
// Pool full, close connection.
if err := ic.conn.Close(); err != nil {
c.log.Debugf("close excess ICAP connection: %v", err)
}
}
}
func (c *ICAPClient) dialConn(serviceURL *url.URL) (*icapConn, error) {
host := serviceURL.Host
if _, _, err := net.SplitHostPort(host); err != nil {
host = net.JoinHostPort(host, icapDefaultPort)
}
conn, err := net.DialTimeout("tcp", host, icapConnTimeout)
if err != nil {
return nil, fmt.Errorf("dial ICAP %s: %w", host, err)
}
return &icapConn{
conn: conn,
reader: bufio.NewReader(conn),
lastUse: time.Now(),
}, nil
}
func parseICAPStatus(line string) (int, error) {
// "ICAP/1.0 200 OK"
parts := strings.SplitN(line, " ", 3)
if len(parts) < 2 {
return 0, fmt.Errorf("malformed ICAP status line: %q", line)
}
code, err := strconv.Atoi(parts[1])
if err != nil {
return 0, fmt.Errorf("parse ICAP status code %q: %w", parts[1], err)
}
return code, nil
}
// chunkedReader reads ICAP chunked encoding (same as HTTP chunked, terminated by "0\r\n\r\n").
type chunkedReader struct {
r *bufio.Reader
remaining int
done bool
}
func newChunkedReader(r *bufio.Reader) *chunkedReader {
return &chunkedReader{r: r}
}
func (cr *chunkedReader) Read(p []byte) (int, error) {
if cr.done {
return 0, io.EOF
}
if cr.remaining == 0 {
// Read chunk size line
line, err := cr.r.ReadString('\n')
if err != nil {
return 0, err
}
line = strings.TrimSpace(line)
// Strip any chunk extensions
if idx := strings.Index(line, ";"); idx >= 0 {
line = line[:idx]
}
size, err := strconv.ParseInt(line, 16, 64)
if err != nil {
return 0, fmt.Errorf("parse chunk size %q: %w", line, err)
}
if size == 0 {
cr.done = true
// Consume trailing \r\n
_, _ = cr.r.ReadString('\n')
return 0, io.EOF
}
if size < 0 || size > icapMaxRespSize {
return 0, fmt.Errorf("chunk size %d out of range (max %d)", size, icapMaxRespSize)
}
cr.remaining = int(size)
}
toRead := len(p)
if toRead > cr.remaining {
toRead = cr.remaining
}
n, err := cr.r.Read(p[:toRead])
cr.remaining -= n
if cr.remaining == 0 {
// Consume chunk-terminating \r\n
_, _ = cr.r.ReadString('\n')
}
return n, err
}

View File

@@ -0,0 +1,21 @@
//go:build !linux
package inspect
import (
"fmt"
"net"
"net/netip"
log "github.com/sirupsen/logrus"
)
// newTPROXYListener is not supported on non-Linux platforms.
func newTPROXYListener(_ *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
return nil, fmt.Errorf("TPROXY listener not supported on this platform (requested %s)", addr)
}
// getOriginalDst is not supported on non-Linux platforms.
func getOriginalDst(_ net.Conn) (netip.AddrPort, error) {
return netip.AddrPort{}, fmt.Errorf("SO_ORIGINAL_DST not supported on this platform")
}

View File

@@ -0,0 +1,89 @@
package inspect
import (
"fmt"
"net"
"net/netip"
"unsafe"
log "github.com/sirupsen/logrus"
"golang.org/x/sys/unix"
)
// newTPROXYListener creates a TCP listener for the transparent proxy.
// After nftables REDIRECT, accepted connections have LocalAddr = WG_IP:proxy_port.
// The original destination is retrieved via getsockopt(SO_ORIGINAL_DST).
func newTPROXYListener(logger *log.Entry, addr netip.AddrPort, _ netip.Prefix) (net.Listener, error) {
ln, err := net.Listen("tcp", addr.String())
if err != nil {
return nil, fmt.Errorf("listen on %s: %w", addr, err)
}
logger.Infof("inspect: listener started on %s", ln.Addr())
return ln, nil
}
// getOriginalDst reads the original destination from conntrack via SO_ORIGINAL_DST.
// This is set by the kernel when the connection was REDIRECT'd/DNAT'd.
// Tries IPv4 first, then falls back to IPv6 (IP6T_SO_ORIGINAL_DST).
func getOriginalDst(conn net.Conn) (netip.AddrPort, error) {
tc, ok := conn.(*net.TCPConn)
if !ok {
return netip.AddrPort{}, fmt.Errorf("not a TCPConn")
}
raw, err := tc.SyscallConn()
if err != nil {
return netip.AddrPort{}, fmt.Errorf("get syscall conn: %w", err)
}
var origDst netip.AddrPort
var sockErr error
if err := raw.Control(func(fd uintptr) {
// Try IPv4 first (SO_ORIGINAL_DST = 80)
var sa4 unix.RawSockaddrInet4
sa4Len := uint32(unsafe.Sizeof(sa4))
_, _, errno := unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IP,
80, // SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa4)),
uintptr(unsafe.Pointer(&sa4Len)),
0,
)
if errno == 0 {
addr := netip.AddrFrom4(sa4.Addr)
port := uint16(sa4.Port>>8) | uint16(sa4.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
return
}
// Fall back to IPv6 (IP6T_SO_ORIGINAL_DST = 80 on SOL_IPV6)
var sa6 unix.RawSockaddrInet6
sa6Len := uint32(unsafe.Sizeof(sa6))
_, _, errno = unix.Syscall6(
unix.SYS_GETSOCKOPT,
fd,
unix.SOL_IPV6,
80, // IP6T_SO_ORIGINAL_DST
uintptr(unsafe.Pointer(&sa6)),
uintptr(unsafe.Pointer(&sa6Len)),
0,
)
if errno != 0 {
sockErr = fmt.Errorf("getsockopt SO_ORIGINAL_DST (v4 and v6): %w", errno)
return
}
addr := netip.AddrFrom16(sa6.Addr)
port := uint16(sa6.Port>>8) | uint16(sa6.Port<<8)
origDst = netip.AddrPortFrom(addr.Unmap(), port)
}); err != nil {
return netip.AddrPort{}, fmt.Errorf("control raw conn: %w", err)
}
if sockErr != nil {
return netip.AddrPort{}, sockErr
}
return origDst, nil
}

200
client/inspect/mitm.go Normal file
View File

@@ -0,0 +1,200 @@
package inspect
import (
"crypto"
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"fmt"
"math/big"
mrand "math/rand/v2"
"sync"
"time"
)
const (
// certCacheSize is the maximum number of cached leaf certificates.
certCacheSize = 1024
// certTTL is how long generated certificates remain valid.
certTTL = 24 * time.Hour
)
// certCache is a bounded LRU cache for generated TLS certificates.
type certCache struct {
mu sync.Mutex
entries map[string]*certEntry
// order tracks LRU eviction, most recent at end.
order []string
maxSize int
}
type certEntry struct {
cert *tls.Certificate
expiresAt time.Time
}
func newCertCache(maxSize int) *certCache {
return &certCache{
entries: make(map[string]*certEntry, maxSize),
order: make([]string, 0, maxSize),
maxSize: maxSize,
}
}
func (c *certCache) get(hostname string) (*tls.Certificate, bool) {
c.mu.Lock()
defer c.mu.Unlock()
entry, ok := c.entries[hostname]
if !ok {
return nil, false
}
if time.Now().After(entry.expiresAt) {
c.removeLocked(hostname)
return nil, false
}
// Move to end (most recently used)
c.touchLocked(hostname)
return entry.cert, true
}
func (c *certCache) put(hostname string, cert *tls.Certificate) {
c.mu.Lock()
defer c.mu.Unlock()
// Jitter the TTL by +/- 20% to prevent thundering herd on expiry.
jitter := time.Duration(float64(certTTL) * (0.8 + 0.4*mrand.Float64()))
if _, exists := c.entries[hostname]; exists {
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.touchLocked(hostname)
return
}
// Evict oldest if at capacity
for len(c.entries) >= c.maxSize && len(c.order) > 0 {
c.removeLocked(c.order[0])
}
c.entries[hostname] = &certEntry{
cert: cert,
expiresAt: time.Now().Add(jitter),
}
c.order = append(c.order, hostname)
}
func (c *certCache) touchLocked(hostname string) {
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
c.order = append(c.order, hostname)
return
}
}
}
func (c *certCache) removeLocked(hostname string) {
delete(c.entries, hostname)
for i, h := range c.order {
if h == hostname {
c.order = append(c.order[:i], c.order[i+1:]...)
return
}
}
}
// CertProvider generates TLS certificates on the fly, signed by a CA.
// Generated certificates are cached in an LRU cache.
type CertProvider struct {
ca *x509.Certificate
caKey crypto.PrivateKey
cache *certCache
}
// NewCertProvider creates a certificate provider using the given CA.
func NewCertProvider(ca *x509.Certificate, caKey crypto.PrivateKey) *CertProvider {
return &CertProvider{
ca: ca,
caKey: caKey,
cache: newCertCache(certCacheSize),
}
}
// GetCertificate returns a TLS certificate for the given hostname,
// generating and caching one if necessary.
func (p *CertProvider) GetCertificate(hostname string) (*tls.Certificate, error) {
if cert, ok := p.cache.get(hostname); ok {
return cert, nil
}
cert, err := p.generateCert(hostname)
if err != nil {
return nil, fmt.Errorf("generate cert for %s: %w", hostname, err)
}
p.cache.put(hostname, cert)
return cert, nil
}
// GetTLSConfig returns a tls.Config that dynamically provides certificates
// for any hostname using the MITM CA.
func (p *CertProvider) GetTLSConfig() *tls.Config {
return &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
return p.GetCertificate(hello.ServerName)
},
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
}
}
func (p *CertProvider) generateCert(hostname string) (*tls.Certificate, error) {
serialNumber, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial number: %w", err)
}
now := time.Now()
template := &x509.Certificate{
SerialNumber: serialNumber,
Subject: pkix.Name{
CommonName: hostname,
},
NotBefore: now.Add(-5 * time.Minute),
NotAfter: now.Add(certTTL),
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
ExtKeyUsage: []x509.ExtKeyUsage{
x509.ExtKeyUsageServerAuth,
},
DNSNames: []string{hostname},
}
leafKey, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate leaf key: %w", err)
}
certDER, err := x509.CreateCertificate(rand.Reader, template, p.ca, &leafKey.PublicKey, p.caKey)
if err != nil {
return nil, fmt.Errorf("sign leaf certificate: %w", err)
}
leafCert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("parse generated certificate: %w", err)
}
return &tls.Certificate{
Certificate: [][]byte{certDER, p.ca.Raw},
PrivateKey: leafKey,
Leaf: leafCert,
}, nil
}

133
client/inspect/mitm_test.go Normal file
View File

@@ -0,0 +1,133 @@
package inspect
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"math/big"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func generateTestCA(t *testing.T) (*x509.Certificate, *ecdsa.PrivateKey) {
t.Helper()
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
require.NoError(t, err)
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: "Test CA",
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
require.NoError(t, err)
cert, err := x509.ParseCertificate(certDER)
require.NoError(t, err)
return cert, key
}
func TestCertProvider_GetCertificate(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert, err := provider.GetCertificate("example.com")
require.NoError(t, err)
require.NotNil(t, cert)
// Verify the leaf certificate
assert.Equal(t, "example.com", cert.Leaf.Subject.CommonName)
assert.Contains(t, cert.Leaf.DNSNames, "example.com")
// Verify chain: leaf + CA
assert.Len(t, cert.Certificate, 2)
// Verify leaf is signed by our CA
pool := x509.NewCertPool()
pool.AddCert(ca)
_, err = cert.Leaf.Verify(x509.VerifyOptions{
Roots: pool,
})
require.NoError(t, err)
}
func TestCertProvider_CachesResults(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("cached.example.com")
require.NoError(t, err)
// Same pointer = cached
assert.Equal(t, cert1, cert2)
}
func TestCertProvider_DifferentHostsDifferentCerts(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
cert1, err := provider.GetCertificate("a.example.com")
require.NoError(t, err)
cert2, err := provider.GetCertificate("b.example.com")
require.NoError(t, err)
assert.NotEqual(t, cert1.Leaf.SerialNumber, cert2.Leaf.SerialNumber)
}
func TestCertProvider_TLSConfigHandshake(t *testing.T) {
ca, caKey := generateTestCA(t)
provider := NewCertProvider(ca, caKey)
tlsConfig := provider.GetTLSConfig()
require.NotNil(t, tlsConfig)
require.NotNil(t, tlsConfig.GetCertificate)
// Simulate a ClientHelloInfo
hello := &tls.ClientHelloInfo{
ServerName: "handshake.example.com",
}
cert, err := tlsConfig.GetCertificate(hello)
require.NoError(t, err)
assert.Equal(t, "handshake.example.com", cert.Leaf.Subject.CommonName)
}
func TestCertCache_Eviction(t *testing.T) {
cache := newCertCache(3)
for i := range 5 {
hostname := string(rune('a'+i)) + ".example.com"
cache.put(hostname, &tls.Certificate{})
}
// Only 3 should remain (c, d, e - the most recent)
assert.Len(t, cache.entries, 3)
_, ok := cache.get("a.example.com")
assert.False(t, ok, "oldest entry should be evicted")
_, ok = cache.get("b.example.com")
assert.False(t, ok, "second oldest should be evicted")
_, ok = cache.get("e.example.com")
assert.True(t, ok, "newest entry should exist")
}

109
client/inspect/peek.go Normal file
View File

@@ -0,0 +1,109 @@
package inspect
import (
"bytes"
"fmt"
"io"
"net"
)
// peekConn wraps a net.Conn with a buffer that allows reading ahead
// without consuming data. Subsequent Read calls return the buffered
// bytes first, then read from the underlying connection.
type peekConn struct {
net.Conn
buf bytes.Buffer
// peeked holds the raw bytes that were peeked, available for replay.
peeked []byte
}
// newPeekConn wraps conn for peek-ahead reading.
func newPeekConn(conn net.Conn) *peekConn {
return &peekConn{Conn: conn}
}
// Peek reads exactly n bytes from the connection without consuming them.
// The peeked bytes are replayed on subsequent Read calls.
// Peek may only be called once; calling it again returns an error.
func (c *peekConn) Peek(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
if _, err := io.ReadFull(c.Conn, buf); err != nil {
return nil, fmt.Errorf("peek %d bytes: %w", n, err)
}
c.peeked = buf
c.buf.Write(buf)
return buf, nil
}
// PeekAll reads up to n bytes, returning whatever is available.
// Unlike Peek, it does not require exactly n bytes.
func (c *peekConn) PeekAll(n int) ([]byte, error) {
if c.peeked != nil {
return nil, fmt.Errorf("peek already called")
}
buf := make([]byte, n)
nr, err := c.Conn.Read(buf)
if nr > 0 {
c.peeked = buf[:nr]
c.buf.Write(c.peeked)
}
if err != nil && nr == 0 {
return nil, fmt.Errorf("peek: %w", err)
}
return c.peeked, nil
}
// PeekMore extends the peeked buffer to at least n total bytes.
// The buffer is reset and refilled with the extended data.
// The returned slice is the internal peeked buffer; callers must not
// retain references from prior Peek/PeekMore calls after calling this.
func (c *peekConn) PeekMore(n int) ([]byte, error) {
if len(c.peeked) >= n {
return c.peeked[:n], nil
}
remaining := n - len(c.peeked)
extra := make([]byte, remaining)
if _, err := io.ReadFull(c.Conn, extra); err != nil {
return nil, fmt.Errorf("peek more %d bytes: %w", remaining, err)
}
// Pre-allocate to avoid reallocation detaching previously returned slices.
combined := make([]byte, 0, n)
combined = append(combined, c.peeked...)
combined = append(combined, extra...)
c.peeked = combined
c.buf.Reset()
c.buf.Write(c.peeked)
return c.peeked, nil
}
// Peeked returns the bytes that were peeked so far, or nil if Peek hasn't been called.
func (c *peekConn) Peeked() []byte {
return c.peeked
}
// Read returns buffered peek data first, then reads from the underlying connection.
func (c *peekConn) Read(p []byte) (int, error) {
if c.buf.Len() > 0 {
return c.buf.Read(p)
}
return c.Conn.Read(p)
}
// reader returns an io.Reader that replays buffered bytes then reads from conn.
func (c *peekConn) reader() io.Reader {
if c.buf.Len() > 0 {
return io.MultiReader(&c.buf, c.Conn)
}
return c.Conn
}

482
client/inspect/proxy.go Normal file
View File

@@ -0,0 +1,482 @@
package inspect
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"sync"
"time"
log "github.com/sirupsen/logrus"
)
// ErrBlocked is returned when a connection is denied by proxy policy.
var ErrBlocked = errors.New("connection blocked by proxy policy")
const (
// headerReadTimeout is the deadline for reading the initial protocol header.
// Prevents slow loris attacks where a client opens a connection but sends data slowly.
headerReadTimeout = 10 * time.Second
// idleTimeout is the deadline for idle connections between HTTP requests.
idleTimeout = 120 * time.Second
)
// Proxy is the inspection engine for traffic passing through a NetBird
// routing peer. It handles protocol detection, rule evaluation, MITM TLS
// decryption, ICAP delegation, and external proxy forwarding.
type Proxy struct {
config Config
rules *RuleEngine
certs *CertProvider
icap *ICAPClient
// envoy is nil unless mode is ModeEnvoy.
envoy *envoyManager
// dialer is the outbound dialer (with SO_MARK cleared on Linux).
dialer net.Dialer
log *log.Entry
// wgNetwork is the WG overlay prefix; dial targets inside it are blocked.
wgNetwork netip.Prefix
// localIPs reports the routing peer's own IPs; dial targets are blocked.
localIPs LocalIPChecker
// listener is the TPROXY/REDIRECT listener for kernel mode.
listener net.Listener
mu sync.RWMutex
ctx context.Context
cancel context.CancelFunc
}
// LocalIPChecker reports whether an IP belongs to the local machine.
type LocalIPChecker interface {
IsLocalIP(netip.Addr) bool
}
// New creates a transparent proxy with the given configuration.
func New(ctx context.Context, logger *log.Entry, config Config) (*Proxy, error) {
ctx, cancel := context.WithCancel(ctx)
p := &Proxy{
config: config,
rules: NewRuleEngine(logger, config.DefaultAction),
dialer: newOutboundDialer(),
log: logger,
wgNetwork: config.WGNetwork,
localIPs: config.LocalIPChecker,
ctx: ctx,
cancel: cancel,
}
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Initialize MITM certificate provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
}
// Initialize ICAP client
if config.ICAP != nil {
p.icap = NewICAPClient(logger, config.ICAP)
}
// Start envoy sidecar if configured
if config.Mode == ModeEnvoy {
envoyLog := logger.WithField("sidecar", "envoy")
em, err := startEnvoy(ctx, envoyLog, config)
if err != nil {
cancel()
return nil, fmt.Errorf("start envoy sidecar: %w", err)
}
p.envoy = em
}
// Start TPROXY listener for kernel mode
if config.ListenAddr.IsValid() {
ln, err := newTPROXYListener(logger, config.ListenAddr, netip.Prefix{})
if err != nil {
cancel()
return nil, fmt.Errorf("start TPROXY listener on %s: %w", config.ListenAddr, err)
}
p.listener = ln
go p.acceptLoop(ln)
}
return p, nil
}
// HandleTCP is the entry point for TCP connections from the userspace forwarder.
// It determines the protocol (TLS or plaintext HTTP), evaluates rules,
// and either blocks, passes through, inspects, or forwards to an external proxy.
func (p *Proxy) HandleTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) error {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
p.mu.RLock()
mode := p.config.Mode
p.mu.RUnlock()
if mode == ModeExternal {
pconn := newPeekConn(clientConn)
return p.handleExternal(ctx, pconn, dst)
}
// Envoy and builtin modes both peek the protocol header for rule evaluation.
// Envoy mode forwards non-blocked traffic to envoy; builtin mode handles all locally.
// TLS blocks are handled by Go (instant close) since envoy can't cleanly RST a TLS connection.
// Built-in and envoy mode: peek 5 bytes (TLS record header size) to determine protocol.
// Set a read deadline to prevent slow loris attacks.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
return fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
return fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
return fmt.Errorf("clear read deadline: %w", err)
}
if isTLSHandshake(header[0]) {
return p.handleTLS(ctx, pconn, dst, src)
}
if isHTTPMethod(header) {
return p.handlePlainHTTP(ctx, pconn, dst, src)
}
// Not TLS and not HTTP: evaluate rules with ProtoOther.
// If no rule explicitly allows "other", this falls through to the default action.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial for passthrough: %w", err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote conn: %v", err)
}
}()
return relay(ctx, pconn, remote)
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
return ErrBlocked
}
// InspectTCP evaluates rules for a TCP connection and returns the result.
// Unlike HandleTCP, it can return early for allow decisions, letting the caller
// handle the relay (USP forwarder passthrough optimization).
//
// When InspectResult.PassthroughConn is non-nil, ownership transfers to the caller:
// the caller must close the connection and relay traffic. The engine does not close it.
//
// When PassthroughConn is nil, the engine handled everything internally
// (block, inspect/MITM, or plain HTTP inspection) and closed the connection.
func (p *Proxy) InspectTCP(ctx context.Context, clientConn net.Conn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
p.mu.RLock()
mode := p.config.Mode
envoy := p.envoy
p.mu.RUnlock()
// External mode: handle internally, engine owns the connection.
if mode == ModeExternal {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
pconn := newPeekConn(clientConn)
err := p.handleExternal(ctx, pconn, dst)
return InspectResult{Action: ActionAllow}, err
}
// Peek protocol header.
if err := clientConn.SetReadDeadline(time.Now().Add(headerReadTimeout)); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("set read deadline: %w", err)
}
pconn := newPeekConn(clientConn)
header, err := pconn.Peek(5)
if err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("peek protocol header: %w", err)
}
if err := clientConn.SetReadDeadline(time.Time{}); err != nil {
clientConn.Close()
return InspectResult{}, fmt.Errorf("clear read deadline: %w", err)
}
// TLS: may return passthrough for allow.
if isTLSHandshake(header[0]) {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil && result.PassthroughConn == nil {
clientConn.Close()
return result, err
}
// Envoy mode: forward allowed TLS to envoy instead of returning passthrough.
if result.PassthroughConn != nil && envoy != nil {
defer clientConn.Close()
envoyErr := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, envoyErr
}
return result, err
}
// Plain HTTP: in envoy mode, forward to envoy for L7 processing.
// In builtin mode, inspect per-request locally.
if isHTTPMethod(header) {
defer func() {
if err := clientConn.Close(); err != nil {
p.log.Debugf("close client conn: %v", err)
}
}()
if envoy != nil {
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
err := p.handlePlainHTTP(ctx, pconn, dst, src)
return InspectResult{Action: ActionInspect}, err
}
// Other protocol: evaluate rules.
action := p.rules.Evaluate(src.IP, "", dst.Addr(), dst.Port(), ProtoOther, "")
if action == ActionAllow {
// Envoy mode: forward to envoy.
if envoy != nil {
defer clientConn.Close()
err := p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
return InspectResult{Action: ActionAllow}, err
}
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
p.log.Debugf("block: non-HTTP/TLS to %s (action=%s, first bytes: %x)", dst, action, header)
clientConn.Close()
return InspectResult{Action: ActionBlock}, ErrBlocked
}
// HandleUDPPacket inspects a UDP packet for QUIC Initial packets.
// Returns the action to take: ActionAllow to continue normal forwarding,
// ActionBlock to drop the packet.
// Non-QUIC packets always return ActionAllow.
func (p *Proxy) HandleUDPPacket(data []byte, dst netip.AddrPort, src SourceInfo) Action {
if len(data) < 5 {
return ActionAllow
}
// Check for QUIC Long Header
if data[0]&0x80 == 0 {
return ActionAllow
}
sni, err := ExtractQUICSNI(data)
if err != nil {
// Can't parse QUIC, allow through (could be non-QUIC UDP)
p.log.Tracef("QUIC SNI extraction failed for %s: %v", dst, err)
return ActionAllow
}
if sni == "" {
return ActionAllow
}
action := p.rules.Evaluate(src.IP, sni, dst.Addr(), dst.Port(), ProtoH3, "")
if action == ActionBlock {
p.log.Debugf("block: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
return ActionBlock
}
// QUIC can't be MITMed, treat Inspect as Allow
if action == ActionInspect {
p.log.Debugf("allow: QUIC to %s (SNI=%s), MITM not supported for QUIC", dst, sni.PunycodeString())
} else {
p.log.Tracef("allow: QUIC to %s (SNI=%s)", dst, sni.PunycodeString())
}
return ActionAllow
}
// handlePlainHTTP handles plaintext HTTP connections.
func (p *Proxy) handlePlainHTTP(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
// For plaintext HTTP, always inspect (we can see the traffic)
return p.inspectHTTP(ctx, pconn, remote, dst, "", src, "http/1.1")
}
// UpdateConfig replaces the inspection engine configuration at runtime.
func (p *Proxy) UpdateConfig(config Config) {
p.log.Debugf("config update: mode=%s rules=%d default=%s has_tls=%v has_icap=%v",
config.Mode, len(config.Rules), config.DefaultAction, config.TLS != nil, config.ICAP != nil)
p.mu.Lock()
p.config = config
p.rules.UpdateRules(config.Rules, config.DefaultAction)
// Update MITM provider
if config.TLS != nil {
p.certs = NewCertProvider(config.TLS.CA, config.TLS.CAKey)
} else {
p.certs = nil
}
// Swap ICAP client under lock, close the old one outside to avoid blocking.
var oldICAP *ICAPClient
if config.ICAP != nil {
oldICAP = p.icap
p.icap = NewICAPClient(p.log, config.ICAP)
} else {
oldICAP = p.icap
p.icap = nil
}
// If switching away from envoy mode, clear and stop the old envoy.
var oldEnvoy *envoyManager
if config.Mode != ModeEnvoy && p.envoy != nil {
oldEnvoy = p.envoy
p.envoy = nil
}
envoy := p.envoy
p.mu.Unlock()
if oldICAP != nil {
oldICAP.Close()
}
if oldEnvoy != nil {
oldEnvoy.Stop()
}
// Reload envoy config if still in envoy mode.
if envoy != nil && config.Mode == ModeEnvoy {
if err := envoy.Reload(config); err != nil {
p.log.Errorf("inspect: envoy config reload: %v", err)
}
}
}
// Mode returns the current proxy operating mode.
func (p *Proxy) Mode() ProxyMode {
p.mu.RLock()
defer p.mu.RUnlock()
return p.config.Mode
}
// ListenPort returns the port to use for kernel-mode nftables REDIRECT.
// For builtin mode: the TPROXY listener port.
// For envoy mode: the envoy listener port (nftables redirects directly to envoy).
// Returns 0 if no listener is active.
func (p *Proxy) ListenPort() uint16 {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return envoy.listenPort
}
if p.listener == nil {
return 0
}
tcpAddr, ok := p.listener.Addr().(*net.TCPAddr)
if !ok {
return 0
}
return uint16(tcpAddr.Port)
}
// Close shuts down the proxy and releases resources.
func (p *Proxy) Close() error {
p.cancel()
p.mu.Lock()
envoy := p.envoy
p.envoy = nil
icap := p.icap
p.icap = nil
p.mu.Unlock()
if envoy != nil {
envoy.Stop()
}
if p.listener != nil {
if err := p.listener.Close(); err != nil {
p.log.Debugf("close TPROXY listener: %v", err)
}
}
if icap != nil {
icap.Close()
}
return nil
}
// acceptLoop accepts connections from the redirected listener (kernel mode).
// Connections arrive via nftables REDIRECT; original destination is read from conntrack.
func (p *Proxy) acceptLoop(ln net.Listener) {
for {
conn, err := ln.Accept()
if err != nil {
if p.ctx.Err() != nil {
return
}
p.log.Debugf("accept error: %v", err)
continue
}
go func() {
// Read original destination from conntrack (SO_ORIGINAL_DST).
// nftables REDIRECT changes dst to the local WG IP:proxy_port,
// but conntrack preserves the real destination.
dstAddr, err := getOriginalDst(conn)
if err != nil {
p.log.Debugf("get original dst: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
p.log.Tracef("accepted: %s -> %s (original dst %s)",
conn.RemoteAddr(), conn.LocalAddr(), dstAddr)
srcAddr, err := netip.ParseAddrPort(conn.RemoteAddr().String())
if err != nil {
p.log.Debugf("parse source: %v", err)
if closeErr := conn.Close(); closeErr != nil {
p.log.Debugf("close conn: %v", closeErr)
}
return
}
src := SourceInfo{
IP: srcAddr.Addr().Unmap(),
}
if err := p.HandleTCP(p.ctx, conn, dstAddr, src); err != nil && !errors.Is(err, ErrBlocked) {
p.log.Debugf("connection to %s: %v", dstAddr, err)
}
}()
}
}

388
client/inspect/quic.go Normal file
View File

@@ -0,0 +1,388 @@
package inspect
import (
"crypto/aes"
"crypto/cipher"
"crypto/sha256"
"encoding/binary"
"fmt"
"io"
"golang.org/x/crypto/hkdf"
"github.com/netbirdio/netbird/shared/management/domain"
)
// QUIC version constants
const (
quicV1Version uint32 = 0x00000001
quicV2Version uint32 = 0x6b3343cf
)
// quicV1Salt is the initial salt for QUIC v1 (RFC 9001 Section 5.2).
var quicV1Salt = []byte{
0x38, 0x76, 0x2c, 0xf7, 0xf5, 0x59, 0x34, 0xb3,
0x4d, 0x17, 0x9a, 0xe6, 0xa4, 0xc8, 0x0c, 0xad,
0xcc, 0xbb, 0x7f, 0x0a,
}
// quicV2Salt is the initial salt for QUIC v2 (RFC 9369).
var quicV2Salt = []byte{
0x0d, 0xed, 0xe3, 0xde, 0xf7, 0x00, 0xa6, 0xdb,
0x81, 0x93, 0x81, 0xbe, 0x6e, 0x26, 0x9d, 0xcb,
0xf9, 0xbd, 0x2e, 0xd9,
}
// ExtractQUICSNI extracts the SNI from a QUIC Initial packet.
// The Initial packet's encryption uses well-known keys derived from the
// Destination Connection ID, so any observer can decrypt it (by design).
func ExtractQUICSNI(data []byte) (domain.Domain, error) {
if len(data) < 5 {
return "", fmt.Errorf("packet too short")
}
// Check for QUIC Long Header (form bit set)
if data[0]&0x80 == 0 {
return "", fmt.Errorf("not a QUIC long header packet")
}
// Version
version := binary.BigEndian.Uint32(data[1:5])
var salt []byte
var initialLabel, keyLabel, ivLabel, hpLabel string
switch version {
case quicV1Version:
salt = quicV1Salt
initialLabel = "client in"
keyLabel = "quic key"
ivLabel = "quic iv"
hpLabel = "quic hp"
case quicV2Version:
salt = quicV2Salt
initialLabel = "client in"
keyLabel = "quicv2 key"
ivLabel = "quicv2 iv"
hpLabel = "quicv2 hp"
default:
return "", fmt.Errorf("unsupported QUIC version: 0x%08x", version)
}
// Parse Long Header
if len(data) < 6 {
return "", fmt.Errorf("packet too short for DCID length")
}
dcidLen := int(data[5])
if len(data) < 6+dcidLen+1 {
return "", fmt.Errorf("packet too short for DCID")
}
dcid := data[6 : 6+dcidLen]
scidLenOff := 6 + dcidLen
scidLen := int(data[scidLenOff])
tokenLenOff := scidLenOff + 1 + scidLen
if tokenLenOff >= len(data) {
return "", fmt.Errorf("packet too short for token length")
}
// Token length is a variable-length integer
tokenLen, tokenLenSize, err := readVarInt(data[tokenLenOff:])
if err != nil {
return "", fmt.Errorf("read token length: %w", err)
}
payloadLenOff := tokenLenOff + tokenLenSize + int(tokenLen)
if payloadLenOff >= len(data) {
return "", fmt.Errorf("packet too short for payload length")
}
// Payload length is a variable-length integer
payloadLen, payloadLenSize, err := readVarInt(data[payloadLenOff:])
if err != nil {
return "", fmt.Errorf("read payload length: %w", err)
}
pnOffset := payloadLenOff + payloadLenSize
if pnOffset+4 > len(data) {
return "", fmt.Errorf("packet too short for packet number")
}
// Derive initial keys
clientKey, clientIV, clientHP, err := deriveInitialKeys(dcid, salt, initialLabel, keyLabel, ivLabel, hpLabel)
if err != nil {
return "", fmt.Errorf("derive initial keys: %w", err)
}
// Remove header protection
sampleOffset := pnOffset + 4 // sample starts 4 bytes after pn offset
if sampleOffset+16 > len(data) {
return "", fmt.Errorf("packet too short for HP sample")
}
sample := data[sampleOffset : sampleOffset+16]
hpBlock, err := aes.NewCipher(clientHP)
if err != nil {
return "", fmt.Errorf("create HP cipher: %w", err)
}
mask := make([]byte, 16)
hpBlock.Encrypt(mask, sample)
// Unmask header byte
header := make([]byte, len(data))
copy(header, data)
header[0] ^= mask[0] & 0x0f // Long header: low 4 bits
// Determine packet number length
pnLen := int(header[0]&0x03) + 1
// Unmask packet number
for i := 0; i < pnLen; i++ {
header[pnOffset+i] ^= mask[1+i]
}
// Reconstruct packet number
var pn uint32
for i := 0; i < pnLen; i++ {
pn = (pn << 8) | uint32(header[pnOffset+i])
}
// Build nonce
nonce := make([]byte, len(clientIV))
copy(nonce, clientIV)
for i := 0; i < 4; i++ {
nonce[len(nonce)-1-i] ^= byte(pn >> (8 * i))
}
// Decrypt payload
block, err := aes.NewCipher(clientKey)
if err != nil {
return "", fmt.Errorf("create AES cipher: %w", err)
}
aead, err := cipher.NewGCM(block)
if err != nil {
return "", fmt.Errorf("create AEAD: %w", err)
}
encryptedPayload := header[pnOffset+pnLen : pnOffset+int(payloadLen)]
aad := header[:pnOffset+pnLen]
plaintext, err := aead.Open(nil, nonce, encryptedPayload, aad)
if err != nil {
return "", fmt.Errorf("decrypt QUIC payload: %w", err)
}
// Parse CRYPTO frames to extract ClientHello
clientHello, err := extractCryptoFrames(plaintext)
if err != nil {
return "", fmt.Errorf("extract CRYPTO frames: %w", err)
}
info, err := parseHelloBody(clientHello)
return info.SNI, err
}
// deriveInitialKeys derives the client's initial encryption keys from the DCID.
func deriveInitialKeys(dcid, salt []byte, initialLabel, keyLabel, ivLabel, hpLabel string) (key, iv, hp []byte, err error) {
// initial_secret = HKDF-Extract(salt, DCID)
initialSecret := hkdf.Extract(sha256.New, dcid, salt)
// client_initial_secret = HKDF-Expand-Label(initial_secret, initialLabel, "", 32)
clientSecret, err := hkdfExpandLabel(initialSecret, initialLabel, nil, 32)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive client secret: %w", err)
}
// client_key = HKDF-Expand-Label(client_secret, keyLabel, "", 16)
key, err = hkdfExpandLabel(clientSecret, keyLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive key: %w", err)
}
// client_iv = HKDF-Expand-Label(client_secret, ivLabel, "", 12)
iv, err = hkdfExpandLabel(clientSecret, ivLabel, nil, 12)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive IV: %w", err)
}
// client_hp = HKDF-Expand-Label(client_secret, hpLabel, "", 16)
hp, err = hkdfExpandLabel(clientSecret, hpLabel, nil, 16)
if err != nil {
return nil, nil, nil, fmt.Errorf("derive HP key: %w", err)
}
return key, iv, hp, nil
}
// hkdfExpandLabel implements TLS 1.3 HKDF-Expand-Label.
func hkdfExpandLabel(secret []byte, label string, context []byte, length int) ([]byte, error) {
// HkdfLabel = struct {
// uint16 length;
// opaque label<7..255> = "tls13 " + Label;
// opaque context<0..255> = Context;
// }
fullLabel := "tls13 " + label
hkdfLabel := make([]byte, 2+1+len(fullLabel)+1+len(context))
binary.BigEndian.PutUint16(hkdfLabel[0:2], uint16(length))
hkdfLabel[2] = byte(len(fullLabel))
copy(hkdfLabel[3:], fullLabel)
hkdfLabel[3+len(fullLabel)] = byte(len(context))
if len(context) > 0 {
copy(hkdfLabel[4+len(fullLabel):], context)
}
expander := hkdf.Expand(sha256.New, secret, hkdfLabel)
out := make([]byte, length)
if _, err := io.ReadFull(expander, out); err != nil {
return nil, err
}
return out, nil
}
// maxCryptoFrameSize limits total CRYPTO frame data to prevent memory exhaustion.
const maxCryptoFrameSize = 64 * 1024
// extractCryptoFrames reassembles CRYPTO frame data from QUIC frames.
func extractCryptoFrames(frames []byte) ([]byte, error) {
var result []byte
pos := 0
for pos < len(frames) {
frameType := frames[pos]
switch {
case frameType == 0x00:
// PADDING frame
pos++
case frameType == 0x06:
// CRYPTO frame
pos++
offset, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto offset: %w", err)
}
pos += n
_ = offset // We assume ordered, offset 0 for Initial
dataLen, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read crypto data length: %w", err)
}
pos += n
end := pos + int(dataLen)
if end > len(frames) {
return nil, fmt.Errorf("CRYPTO frame data truncated")
}
result = append(result, frames[pos:end]...)
if len(result) > maxCryptoFrameSize {
return nil, fmt.Errorf("CRYPTO frame data exceeds %d bytes", maxCryptoFrameSize)
}
pos = end
case frameType == 0x01:
// PING frame
pos++
case frameType == 0x02 || frameType == 0x03:
// ACK frame - skip
pos++
// Largest Acknowledged
_, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK: %w", err)
}
pos += n
// ACK Delay
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK delay: %w", err)
}
pos += n
// ACK Range Count
rangeCount, n, err := readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range count: %w", err)
}
pos += n
// First ACK Range
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read first ACK range: %w", err)
}
pos += n
// Additional ranges
for i := uint64(0); i < rangeCount; i++ {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK gap: %w", err)
}
pos += n
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ACK range: %w", err)
}
pos += n
}
// ECN counts for type 0x03
if frameType == 0x03 {
for range 3 {
_, n, err = readVarInt(frames[pos:])
if err != nil {
return nil, fmt.Errorf("read ECN count: %w", err)
}
pos += n
}
}
default:
// Unknown frame type, stop parsing
if len(result) > 0 {
return result, nil
}
return nil, fmt.Errorf("unknown QUIC frame type: 0x%02x at offset %d", frameType, pos)
}
}
if len(result) == 0 {
return nil, fmt.Errorf("no CRYPTO frames found")
}
return result, nil
}
// readVarInt reads a QUIC variable-length integer.
// Returns (value, bytes consumed, error).
func readVarInt(data []byte) (uint64, int, error) {
if len(data) == 0 {
return 0, 0, fmt.Errorf("empty data for varint")
}
prefix := data[0] >> 6
length := 1 << prefix
if len(data) < length {
return 0, 0, fmt.Errorf("varint truncated: need %d, have %d", length, len(data))
}
var val uint64
switch length {
case 1:
val = uint64(data[0] & 0x3f)
case 2:
val = uint64(binary.BigEndian.Uint16(data[:2])) & 0x3fff
case 4:
val = uint64(binary.BigEndian.Uint32(data[:4])) & 0x3fffffff
case 8:
val = binary.BigEndian.Uint64(data[:8]) & 0x3fffffffffffffff
}
return val, length, nil
}

View File

@@ -0,0 +1,99 @@
package inspect
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestReadVarInt(t *testing.T) {
tests := []struct {
name string
data []byte
want uint64
n int
}{
{
name: "1 byte value",
data: []byte{0x25},
want: 37,
n: 1,
},
{
name: "2 byte value",
data: []byte{0x7b, 0xbd},
want: 15293,
n: 2,
},
{
name: "4 byte value",
data: []byte{0x9d, 0x7f, 0x3e, 0x7d},
want: 494878333,
n: 4,
},
{
name: "zero",
data: []byte{0x00},
want: 0,
n: 1,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
val, n, err := readVarInt(tt.data)
require.NoError(t, err)
assert.Equal(t, tt.want, val)
assert.Equal(t, tt.n, n)
})
}
}
func TestReadVarInt_Empty(t *testing.T) {
_, _, err := readVarInt(nil)
require.Error(t, err)
}
func TestReadVarInt_Truncated(t *testing.T) {
// 2-byte prefix but only 1 byte
_, _, err := readVarInt([]byte{0x40})
require.Error(t, err)
}
func TestExtractQUICSNI_NotLongHeader(t *testing.T) {
// Short header packet (form bit not set)
data := make([]byte, 100)
data[0] = 0x40 // short header
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "not a QUIC long header")
}
func TestExtractQUICSNI_UnsupportedVersion(t *testing.T) {
data := make([]byte, 100)
data[0] = 0xC0 // long header
// Version 0xdeadbeef
data[1] = 0xde
data[2] = 0xad
data[3] = 0xbe
data[4] = 0xef
_, err := ExtractQUICSNI(data)
require.Error(t, err)
assert.Contains(t, err.Error(), "unsupported QUIC version")
}
func TestExtractQUICSNI_TooShort(t *testing.T) {
_, err := ExtractQUICSNI([]byte{0xC0, 0x00})
require.Error(t, err)
}
func TestHkdfExpandLabel(t *testing.T) {
// Smoke test: ensure it returns the right length and doesn't error
secret := make([]byte, 32)
result, err := hkdfExpandLabel(secret, "quic key", nil, 16)
require.NoError(t, err)
assert.Len(t, result, 16)
}

253
client/inspect/rules.go Normal file
View File

@@ -0,0 +1,253 @@
package inspect
import (
"net/netip"
"slices"
"sort"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
// RuleEngine evaluates proxy rules against connection metadata.
// It is safe for concurrent use.
type RuleEngine struct {
mu sync.RWMutex
rules []Rule
// defaultAction applies when no rule matches.
defaultAction Action
log *log.Entry
}
// NewRuleEngine creates a rule engine with the given default action.
func NewRuleEngine(logger *log.Entry, defaultAction Action) *RuleEngine {
return &RuleEngine{
defaultAction: defaultAction,
log: logger,
}
}
// UpdateRules replaces the rule set and default action. Rules are sorted by priority.
func (e *RuleEngine) UpdateRules(rules []Rule, defaultAction Action) {
sorted := make([]Rule, len(rules))
copy(sorted, rules)
sort.Slice(sorted, func(i, j int) bool {
return sorted[i].Priority < sorted[j].Priority
})
e.mu.Lock()
e.rules = sorted
e.defaultAction = defaultAction
e.mu.Unlock()
}
// EvalResult holds the outcome of a rule evaluation.
type EvalResult struct {
Action Action
RuleID id.RuleID
}
// Evaluate determines the action for a connection based on the rule set.
// Pass empty path for connection-level evaluation (TLS/SNI), non-empty for request-level (HTTP).
func (e *RuleEngine) Evaluate(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) Action {
r := e.EvaluateWithResult(src, dstDomain, dstAddr, dstPort, proto, path)
return r.Action
}
// EvaluateWithResult is like Evaluate but also returns the matched rule ID.
func (e *RuleEngine) EvaluateWithResult(src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) EvalResult {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
rule := &e.rules[i]
if e.ruleMatches(rule, src, dstDomain, dstAddr, dstPort, proto, path) {
e.log.Tracef("rule %s matched: action=%s src=%s domain=%s dst=%s:%d proto=%s path=%s",
rule.ID, rule.Action, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: rule.Action, RuleID: rule.ID}
}
}
e.log.Tracef("no rule matched, default=%s: src=%s domain=%s dst=%s:%d proto=%s path=%s",
e.defaultAction, src, dstDomain.SafeString(), dstAddr, dstPort, proto, path)
return EvalResult{Action: e.defaultAction}
}
// HasPathRulesForDomain returns true if any rule matching the domain has non-empty Paths.
// Used to force MITM inspection when path-level rules exist (paths are only visible after decryption).
func (e *RuleEngine) HasPathRulesForDomain(dstDomain domain.Domain) bool {
e.mu.RLock()
defer e.mu.RUnlock()
for i := range e.rules {
if len(e.rules[i].Paths) > 0 && e.matchDomain(&e.rules[i], dstDomain) {
return true
}
}
return false
}
// ruleMatches checks whether all non-empty fields of a rule match.
// Empty fields are treated as "match any".
// All specified fields must match (AND logic).
func (e *RuleEngine) ruleMatches(rule *Rule, src netip.Addr, dstDomain domain.Domain, dstAddr netip.Addr, dstPort uint16, proto ProtoType, path string) bool {
if !e.matchSource(rule, src) {
return false
}
if !e.matchDomain(rule, dstDomain) {
return false
}
if !e.matchNetwork(rule, dstAddr) {
return false
}
if !e.matchPort(rule, dstPort) {
return false
}
if !e.matchProtocol(rule, proto) {
return false
}
if !e.matchPaths(rule, path) {
return false
}
return true
}
// matchSource returns true if src matches any of the rule's source CIDRs,
// or if no source CIDRs are specified (match any).
func (e *RuleEngine) matchSource(rule *Rule, src netip.Addr) bool {
if len(rule.Sources) == 0 {
return true
}
for _, prefix := range rule.Sources {
if prefix.Contains(src) {
return true
}
}
return false
}
// matchDomain returns true if dstDomain matches any of the rule's domain patterns,
// or if no domain patterns are specified (match any).
func (e *RuleEngine) matchDomain(rule *Rule, dstDomain domain.Domain) bool {
if len(rule.Domains) == 0 {
return true
}
// If we have domain rules but no domain to match against (e.g., raw IP connection),
// the domain condition does not match.
if dstDomain == "" {
return false
}
for _, pattern := range rule.Domains {
if MatchDomain(pattern, dstDomain) {
return true
}
}
return false
}
// matchNetwork returns true if dstAddr is within any of the rule's destination CIDRs,
// or if no destination CIDRs are specified (match any).
func (e *RuleEngine) matchNetwork(rule *Rule, dstAddr netip.Addr) bool {
if len(rule.Networks) == 0 {
return true
}
for _, prefix := range rule.Networks {
if prefix.Contains(dstAddr) {
return true
}
}
return false
}
// matchProtocol returns true if proto matches any of the rule's protocols,
// or if no protocols are specified (match any).
func (e *RuleEngine) matchProtocol(rule *Rule, proto ProtoType) bool {
if len(rule.Protocols) == 0 {
return true
}
for _, p := range rule.Protocols {
if p == proto {
return true
}
}
return false
}
// matchPort returns true if dstPort matches any of the rule's destination ports,
// or if no ports are specified (match any).
func (e *RuleEngine) matchPort(rule *Rule, dstPort uint16) bool {
if len(rule.Ports) == 0 {
return true
}
return slices.Contains(rule.Ports, dstPort)
}
// matchPaths returns true if path matches any of the rule's path patterns,
// or if no paths are specified (match any). Empty path (connection-level eval) matches all.
func (e *RuleEngine) matchPaths(rule *Rule, path string) bool {
if len(rule.Paths) == 0 {
return true
}
// Connection-level (path=""): rules with paths don't match at connection level.
// HasPathRulesForDomain forces the connection to inspect, so paths are
// checked per-request once the HTTP request is visible.
if path == "" {
return false
}
for _, pattern := range rule.Paths {
if matchPath(pattern, path) {
return true
}
}
return false
}
// matchPath checks if a URL path matches a pattern.
// Supports: exact ("/login"), prefix with wildcard ("/api/*"),
// and contains ("*/admin/*"). A bare "*" matches everything.
func matchPath(pattern, path string) bool {
if pattern == "*" {
return true
}
hasLeadingStar := strings.HasPrefix(pattern, "*")
hasTrailingStar := strings.HasSuffix(pattern, "*")
switch {
case hasLeadingStar && hasTrailingStar:
// */admin/* = contains
middle := strings.Trim(pattern, "*")
return strings.Contains(path, middle)
case hasTrailingStar:
// /api/* = prefix
prefix := strings.TrimSuffix(pattern, "*")
return strings.HasPrefix(path, prefix)
case hasLeadingStar:
// *.json = suffix
suffix := strings.TrimPrefix(pattern, "*")
return strings.HasSuffix(path, suffix)
default:
// exact
return path == pattern
}
}

View File

@@ -0,0 +1,338 @@
package inspect
import (
"net/netip"
"testing"
log "github.com/sirupsen/logrus"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
)
func testLogger() *log.Entry {
return log.WithField("test", true)
}
func mustDomain(t *testing.T, s string) domain.Domain {
t.Helper()
d, err := domain.FromString(s)
require.NoError(t, err)
return d
}
func TestRuleEngine_Evaluate(t *testing.T) {
tests := []struct {
name string
rules []Rule
defaultAction Action
src netip.Addr
dstDomain domain.Domain
dstAddr netip.Addr
dstPort uint16
want Action
}{
{
name: "no rules returns default allow",
defaultAction: ActionAllow,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "no rules returns default block",
defaultAction: ActionBlock,
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain exact match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "malware.example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "malware.example.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard match blocks",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "phishing.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "domain wildcard does not match base",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "case insensitive domain match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "Example.COM")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "EXAMPLE.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionBlock,
},
{
name: "source CIDR match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.50"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "source CIDR no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.5"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
{
name: "destination network match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Networks: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/8")},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("192.168.1.1"),
dstAddr: netip.MustParseAddr("10.50.0.1"),
dstPort: 80,
want: ActionInspect,
},
{
name: "port match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionInspect,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionInspect,
},
{
name: "port no match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Ports: []uint16{443, 8443},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 22,
want: ActionAllow,
},
{
name: "priority ordering first match wins",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("allow-internal"),
Domains: []domain.Domain{mustDomain(t, "*.internal.corp")},
Action: ActionAllow,
Priority: 1,
},
{
ID: id.RuleID("inspect-all"),
Action: ActionInspect,
Priority: 10,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: mustDomain(t, "api.internal.corp"),
dstAddr: netip.MustParseAddr("10.1.0.5"),
dstPort: 443,
want: ActionAllow,
},
{
name: "all fields must match (AND logic)",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Sources: []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
Domains: []domain.Domain{mustDomain(t, "*.evil.com")},
Ports: []uint16{443},
Action: ActionBlock,
},
},
// Source matches, domain matches, but port doesn't
src: netip.MustParseAddr("192.168.1.10"),
dstDomain: mustDomain(t, "phish.evil.com"),
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 8080,
want: ActionAllow,
},
{
name: "empty domain with domain rule does not match",
defaultAction: ActionAllow,
rules: []Rule{
{
ID: id.RuleID("r1"),
Domains: []domain.Domain{mustDomain(t, "example.com")},
Action: ActionBlock,
},
},
src: netip.MustParseAddr("10.0.0.1"),
dstDomain: "", // raw IP connection, no SNI
dstAddr: netip.MustParseAddr("1.2.3.4"),
dstPort: 443,
want: ActionAllow,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
engine := NewRuleEngine(testLogger(), tt.defaultAction)
engine.UpdateRules(tt.rules, tt.defaultAction)
got := engine.Evaluate(tt.src, tt.dstDomain, tt.dstAddr, tt.dstPort, "", "")
assert.Equal(t, tt.want, got)
})
}
}
func TestRuleEngine_ProtocolMatching(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-websocket",
Protocols: []ProtoType{ProtoWebSocket},
Action: ActionBlock,
Priority: 1,
},
{
ID: "inspect-h2",
Protocols: []ProtoType{ProtoH2},
Action: ActionInspect,
Priority: 2,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
// WebSocket: blocked by rule
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
// HTTP/2: inspected by rule
assert.Equal(t, ActionInspect, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
// Plain HTTP: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 80, ProtoHTTP, ""))
// HTTPS: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
// QUIC/H3: no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, ProtoH3, ""))
// Empty protocol (unknown): no protocol rule matches, default allow
assert.Equal(t, ActionAllow, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_EmptyProtocolsMatchAll(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{
ID: "block-all-protos",
Action: ActionBlock,
// No Protocols field = match all protocols
Priority: 1,
},
}, ActionAllow)
src := netip.MustParseAddr("10.0.0.1")
dst := netip.MustParseAddr("1.2.3.4")
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTP, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoHTTPS, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoWebSocket, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, ProtoH2, ""))
assert.Equal(t, ActionBlock, engine.Evaluate(src, "", dst, 443, "", ""))
}
func TestRuleEngine_UpdateRulesSortsByPriority(t *testing.T) {
engine := NewRuleEngine(testLogger(), ActionAllow)
engine.UpdateRules([]Rule{
{ID: "c", Priority: 30, Action: ActionBlock},
{ID: "a", Priority: 10, Action: ActionInspect},
{ID: "b", Priority: 20, Action: ActionAllow},
}, ActionAllow)
engine.mu.RLock()
defer engine.mu.RUnlock()
require.Len(t, engine.rules, 3)
assert.Equal(t, id.RuleID("a"), engine.rules[0].ID)
assert.Equal(t, id.RuleID("b"), engine.rules[1].ID)
assert.Equal(t, id.RuleID("c"), engine.rules[2].ID)
}

287
client/inspect/sni.go Normal file
View File

@@ -0,0 +1,287 @@
package inspect
import (
"encoding/binary"
"fmt"
"io"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
recordTypeHandshake = 0x16
handshakeTypeClientHello = 0x01
extensionTypeSNI = 0x0000
extensionTypeALPN = 0x0010
sniTypeHostName = 0x00
// maxClientHelloSize is the maximum ClientHello size we'll read.
// Real-world ClientHellos are typically under 1KB but can reach ~16KB with
// many extensions (post-quantum key shares, etc.).
maxClientHelloSize = 16384
)
// ClientHelloInfo holds data extracted from a TLS ClientHello.
type ClientHelloInfo struct {
SNI domain.Domain
ALPN []string
}
// isTLSHandshake reports whether the first byte indicates a TLS handshake record.
func isTLSHandshake(b byte) bool {
return b == recordTypeHandshake
}
// httpMethods lists the first bytes of valid HTTP method tokens.
var httpMethods = [][]byte{
[]byte("GET "),
[]byte("POST"),
[]byte("PUT "),
[]byte("DELE"),
[]byte("HEAD"),
[]byte("OPTI"),
[]byte("PATC"),
[]byte("CONN"),
[]byte("TRAC"),
}
// isHTTPMethod reports whether the peeked bytes look like the start of an HTTP request.
func isHTTPMethod(b []byte) bool {
if len(b) < 4 {
return false
}
for _, m := range httpMethods {
if b[0] == m[0] && b[1] == m[1] && b[2] == m[2] && b[3] == m[3] {
return true
}
}
return false
}
// parseClientHello reads a TLS ClientHello from r and returns SNI and ALPN.
func parseClientHello(r io.Reader) (ClientHelloInfo, error) {
// TLS record header: type(1) + version(2) + length(2)
var recordHeader [5]byte
if _, err := io.ReadFull(r, recordHeader[:]); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read TLS record header: %w", err)
}
if recordHeader[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", recordHeader[0])
}
recordLen := int(binary.BigEndian.Uint16(recordHeader[3:5]))
if recordLen < 4 || recordLen > maxClientHelloSize {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
// Read the full handshake message
msg := make([]byte, recordLen)
if _, err := io.ReadFull(r, msg); err != nil {
return ClientHelloInfo{}, fmt.Errorf("read handshake message: %w", err)
}
return parseClientHelloMsg(msg)
}
// extractSNI reads a TLS ClientHello from r and returns the SNI hostname.
// Returns empty domain if no SNI extension is present.
func extractSNI(r io.Reader) (domain.Domain, error) {
info, err := parseClientHello(r)
return info.SNI, err
}
// extractSNIFromBytes parses SNI from raw bytes that start with the TLS record header.
func extractSNIFromBytes(data []byte) (domain.Domain, error) {
info, err := parseClientHelloFromBytes(data)
return info.SNI, err
}
// parseClientHelloFromBytes parses a ClientHello from raw bytes starting with the TLS record header.
func parseClientHelloFromBytes(data []byte) (ClientHelloInfo, error) {
if len(data) < 5 {
return ClientHelloInfo{}, fmt.Errorf("data too short for TLS record header")
}
if data[0] != recordTypeHandshake {
return ClientHelloInfo{}, fmt.Errorf("not a TLS handshake record (type=%d)", data[0])
}
recordLen := int(binary.BigEndian.Uint16(data[3:5]))
if recordLen < 4 {
return ClientHelloInfo{}, fmt.Errorf("invalid TLS record length: %d", recordLen)
}
end := 5 + recordLen
if end > len(data) {
return ClientHelloInfo{}, fmt.Errorf("TLS record truncated: need %d, have %d", end, len(data))
}
return parseClientHelloMsg(data[5:end])
}
// parseClientHelloMsg extracts SNI and ALPN from a raw ClientHello handshake message.
// msg starts at the handshake type byte.
func parseClientHelloMsg(msg []byte) (ClientHelloInfo, error) {
if len(msg) < 4 {
return ClientHelloInfo{}, fmt.Errorf("handshake message too short")
}
if msg[0] != handshakeTypeClientHello {
return ClientHelloInfo{}, fmt.Errorf("not a ClientHello (type=%d)", msg[0])
}
// Handshake header: type(1) + length(3)
helloLen := int(msg[1])<<16 | int(msg[2])<<8 | int(msg[3])
if helloLen+4 > len(msg) {
return ClientHelloInfo{}, fmt.Errorf("ClientHello truncated")
}
hello := msg[4 : 4+helloLen]
return parseHelloBody(hello)
}
// parseHelloBody parses the ClientHello body (after handshake header)
// and extracts SNI and ALPN.
func parseHelloBody(hello []byte) (ClientHelloInfo, error) {
// ClientHello structure:
// version(2) + random(32) + session_id_len(1) + session_id(var)
// + cipher_suites_len(2) + cipher_suites(var)
// + compression_len(1) + compression(var)
// + extensions_len(2) + extensions(var)
var info ClientHelloInfo
if len(hello) < 35 {
return info, fmt.Errorf("ClientHello body too short")
}
pos := 2 + 32 // skip version + random
// Skip session ID
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at session ID")
}
sessionIDLen := int(hello[pos])
pos += 1 + sessionIDLen
// Skip cipher suites
if pos+2 > len(hello) {
return info, fmt.Errorf("ClientHello truncated at cipher suites")
}
cipherLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2 + cipherLen
// Skip compression methods
if pos >= len(hello) {
return info, fmt.Errorf("ClientHello truncated at compression")
}
compLen := int(hello[pos])
pos += 1 + compLen
// Extensions
if pos+2 > len(hello) {
return info, nil
}
extLen := int(binary.BigEndian.Uint16(hello[pos : pos+2]))
pos += 2
extEnd := pos + extLen
if extEnd > len(hello) {
return info, fmt.Errorf("extensions block truncated")
}
// Walk extensions looking for SNI and ALPN
for pos+4 <= extEnd {
extType := binary.BigEndian.Uint16(hello[pos : pos+2])
extDataLen := int(binary.BigEndian.Uint16(hello[pos+2 : pos+4]))
pos += 4
if pos+extDataLen > extEnd {
return info, fmt.Errorf("extension data truncated")
}
switch extType {
case extensionTypeSNI:
sni, err := parseSNIExtension(hello[pos : pos+extDataLen])
if err != nil {
return info, err
}
info.SNI = sni
case extensionTypeALPN:
info.ALPN = parseALPNExtension(hello[pos : pos+extDataLen])
}
pos += extDataLen
}
return info, nil
}
// parseALPNExtension parses the ALPN extension data and returns protocol names.
// ALPN extension: list_length(2) + entries (each: len(1) + protocol_name(var))
func parseALPNExtension(data []byte) []string {
if len(data) < 2 {
return nil
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return nil
}
var protocols []string
pos := 2
end := 2 + listLen
for pos < end {
if pos >= len(data) {
break
}
nameLen := int(data[pos])
pos++
if pos+nameLen > end {
break
}
protocols = append(protocols, string(data[pos:pos+nameLen]))
pos += nameLen
}
return protocols
}
// parseSNIExtension parses the SNI extension data and returns the hostname.
func parseSNIExtension(data []byte) (domain.Domain, error) {
// SNI extension: list_length(2) + entries
if len(data) < 2 {
return "", fmt.Errorf("SNI extension too short")
}
listLen := int(binary.BigEndian.Uint16(data[0:2]))
if listLen+2 > len(data) {
return "", fmt.Errorf("SNI list truncated")
}
pos := 2
end := 2 + listLen
for pos+3 <= end {
nameType := data[pos]
nameLen := int(binary.BigEndian.Uint16(data[pos+1 : pos+3]))
pos += 3
if pos+nameLen > end {
return "", fmt.Errorf("SNI name truncated")
}
if nameType == sniTypeHostName {
hostname := string(data[pos : pos+nameLen])
return domain.FromString(hostname)
}
pos += nameLen
}
return "", nil
}

109
client/inspect/sni_test.go Normal file
View File

@@ -0,0 +1,109 @@
package inspect
import (
"bytes"
"crypto/tls"
"net"
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
)
func TestExtractSNI(t *testing.T) {
tests := []struct {
name string
sni string
wantSNI string
wantErr bool
}{
{
name: "standard domain",
sni: "example.com",
wantSNI: "example.com",
},
{
name: "subdomain",
sni: "api.staging.example.com",
wantSNI: "api.staging.example.com",
},
{
name: "mixed case normalized to lowercase",
sni: "Example.COM",
wantSNI: "example.com",
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
clientHello := buildClientHello(t, tt.sni)
sni, err := extractSNI(bytes.NewReader(clientHello))
if tt.wantErr {
require.Error(t, err)
return
}
require.NoError(t, err)
assert.Equal(t, tt.wantSNI, sni.PunycodeString())
})
}
}
func TestExtractSNI_NotTLS(t *testing.T) {
// HTTP request instead of TLS
data := []byte("GET / HTTP/1.1\r\nHost: example.com\r\n\r\n")
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
assert.Contains(t, err.Error(), "not a TLS handshake")
}
func TestExtractSNI_Truncated(t *testing.T) {
// Just the record header, no body
data := []byte{0x16, 0x03, 0x01, 0x00, 0x05}
_, err := extractSNI(bytes.NewReader(data))
require.Error(t, err)
}
func TestExtractSNIFromBytes(t *testing.T) {
clientHello := buildClientHello(t, "test.example.com")
sni, err := extractSNIFromBytes(clientHello)
require.NoError(t, err)
assert.Equal(t, "test.example.com", sni.PunycodeString())
}
// buildClientHello generates a real TLS ClientHello with the given SNI.
func buildClientHello(t *testing.T, serverName string) []byte {
t.Helper()
// Use a pipe to capture the ClientHello bytes
clientConn, serverConn := net.Pipe()
done := make(chan []byte, 1)
go func() {
buf := make([]byte, 4096)
n, _ := serverConn.Read(buf)
done <- buf[:n]
serverConn.Close()
}()
tlsConn := tls.Client(clientConn, &tls.Config{
ServerName: serverName,
InsecureSkipVerify: true,
})
// Trigger the handshake (will fail since server isn't TLS, but we capture the ClientHello)
go func() {
_ = tlsConn.Handshake()
tlsConn.Close()
}()
clientHello := <-done
clientConn.Close()
require.True(t, len(clientHello) > 5, "ClientHello too short")
require.Equal(t, byte(0x16), clientHello[0], "not a TLS handshake record")
return clientHello
}

287
client/inspect/tls.go Normal file
View File

@@ -0,0 +1,287 @@
package inspect
import (
"context"
"crypto/tls"
"fmt"
"io"
"net"
"net/netip"
"github.com/netbirdio/netbird/shared/management/domain"
)
// handleTLS processes a TLS connection for the kernel-mode path: extracts SNI,
// evaluates rules, and handles the connection internally.
// In envoy mode, allowed connections are forwarded to envoy instead of direct relay.
func (p *Proxy) handleTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) error {
result, err := p.inspectTLS(ctx, pconn, dst, src)
if err != nil {
return err
}
if result.PassthroughConn != nil {
p.mu.RLock()
envoy := p.envoy
p.mu.RUnlock()
if envoy != nil {
return p.forwardToEnvoy(ctx, pconn, dst, src, envoy)
}
return p.tlsPassthrough(ctx, pconn, dst, "")
}
return nil
}
// inspectTLS extracts SNI, evaluates rules, and returns the result.
// For ActionAllow: returns the peekConn as PassthroughConn (caller relays).
// For ActionBlock/ActionInspect: handles internally and returns nil PassthroughConn.
func (p *Proxy) inspectTLS(ctx context.Context, pconn *peekConn, dst netip.AddrPort, src SourceInfo) (InspectResult, error) {
// The first 5 bytes (TLS record header) are already peeked.
// Extend to read the full TLS record so bytes remain in the buffer for passthrough.
peeked := pconn.Peeked()
recordLen := int(peeked[3])<<8 | int(peeked[4])
if _, err := pconn.PeekMore(5 + recordLen); err != nil {
return InspectResult{}, fmt.Errorf("read TLS record: %w", err)
}
hello, err := parseClientHelloFromBytes(pconn.Peeked())
if err != nil {
return InspectResult{}, fmt.Errorf("parse ClientHello: %w", err)
}
sni := hello.SNI
proto := protoFromALPN(hello.ALPN)
// Connection-level evaluation: pass empty path.
action := p.evaluateAction(src.IP, sni, dst, proto, "")
// If any rule for this domain has path patterns, force inspect so paths can
// be checked per-request after MITM decryption.
if action == ActionAllow && p.rules.HasPathRulesForDomain(sni) {
p.log.Debugf("upgrading to inspect for %s (path rules exist)", sni.PunycodeString())
action = ActionInspect
}
// Snapshot cert provider under lock for use in this connection.
p.mu.RLock()
certs := p.certs
p.mu.RUnlock()
switch action {
case ActionBlock:
p.log.Debugf("block: TLS to %s (SNI=%s)", dst, sni.PunycodeString())
if certs != nil {
return InspectResult{Action: ActionBlock}, p.tlsBlockPage(ctx, pconn, sni, certs)
}
return InspectResult{Action: ActionBlock}, ErrBlocked
case ActionAllow:
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
case ActionInspect:
if certs == nil {
p.log.Warnf("allow: %s (inspect requested but no MITM CA configured)", sni.PunycodeString())
return InspectResult{Action: ActionAllow, PassthroughConn: pconn}, nil
}
err := p.tlsMITM(ctx, pconn, dst, sni, src, certs)
return InspectResult{Action: ActionInspect}, err
default:
p.log.Warnf("block: unknown action %q for %s", action, sni.PunycodeString())
return InspectResult{Action: ActionBlock}, ErrBlocked
}
}
// tlsBlockPage completes a MITM TLS handshake with the client using a dynamic
// certificate, then serves an HTTP 403 block page so the user sees a clear
// message instead of a cryptic SSL error.
func (p *Proxy) tlsBlockPage(ctx context.Context, pconn *peekConn, sni domain.Domain, certs *CertProvider) error {
hostname := sni.PunycodeString()
// Force HTTP/1.1 only: block pages are simple responses, no need for h2
tlsCfg := certs.GetTLSConfig()
tlsCfg.NextProtos = []string{"http/1.1"}
clientTLS := tls.Server(pconn, tlsCfg)
if err := clientTLS.HandshakeContext(ctx); err != nil {
// Client may not trust our CA, handshake fails. That's expected.
return fmt.Errorf("block page TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close block page TLS for %s: %v", hostname, err)
}
}()
writeBlockResponse(clientTLS, nil, sni)
return ErrBlocked
}
// tlsPassthrough connects to the destination and relays encrypted traffic
// without decryption. The peeked ClientHello bytes are replayed.
func (p *Proxy) tlsPassthrough(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain) error {
remote, err := p.dialTCP(ctx, dst)
if err != nil {
return fmt.Errorf("dial %s: %w", dst, err)
}
defer func() {
if err := remote.Close(); err != nil {
p.log.Debugf("close remote for %s: %v", dst, err)
}
}()
p.log.Tracef("allow: TLS passthrough to %s (SNI=%s)", dst, sni.PunycodeString())
return relay(ctx, pconn, remote)
}
// tlsMITM terminates the client TLS connection with a dynamic certificate,
// establishes a new TLS connection to the real destination, and runs the
// HTTP inspection pipeline on the decrypted traffic.
func (p *Proxy) tlsMITM(ctx context.Context, pconn *peekConn, dst netip.AddrPort, sni domain.Domain, src SourceInfo, certs *CertProvider) error {
hostname := sni.PunycodeString()
// TLS handshake with client using dynamic cert
clientTLS := tls.Server(pconn, certs.GetTLSConfig())
if err := clientTLS.HandshakeContext(ctx); err != nil {
return fmt.Errorf("client TLS handshake for %s: %w", hostname, err)
}
defer func() {
if err := clientTLS.Close(); err != nil {
p.log.Debugf("close client TLS for %s: %v", hostname, err)
}
}()
// TLS connection to real destination
remoteTLS, err := p.dialTLS(ctx, dst, hostname)
if err != nil {
return fmt.Errorf("dial TLS %s (%s): %w", dst, hostname, err)
}
defer func() {
if err := remoteTLS.Close(); err != nil {
p.log.Debugf("close remote TLS for %s: %v", hostname, err)
}
}()
negotiatedProto := clientTLS.ConnectionState().NegotiatedProtocol
p.log.Tracef("inspect: MITM established for %s (proto=%s)", hostname, negotiatedProto)
return p.inspectHTTP(ctx, clientTLS, remoteTLS, dst, sni, src, negotiatedProto)
}
// dialTLS connects to the destination with TLS, verifying the real server certificate.
func (p *Proxy) dialTLS(ctx context.Context, dst netip.AddrPort, serverName string) (net.Conn, error) {
rawConn, err := p.dialTCP(ctx, dst)
if err != nil {
return nil, err
}
tlsConn := tls.Client(rawConn, &tls.Config{
ServerName: serverName,
NextProtos: []string{"h2", "http/1.1"},
MinVersion: tls.VersionTLS12,
})
if err := tlsConn.HandshakeContext(ctx); err != nil {
if closeErr := rawConn.Close(); closeErr != nil {
p.log.Debugf("close raw conn after TLS handshake failure: %v", closeErr)
}
return nil, fmt.Errorf("TLS handshake with %s: %w", serverName, err)
}
return tlsConn, nil
}
// protoFromALPN maps TLS ALPN protocol names to proxy ProtoType.
// Falls back to ProtoHTTPS when no recognized ALPN is present.
func protoFromALPN(alpn []string) ProtoType {
for _, p := range alpn {
switch p {
case "h2":
return ProtoH2
case "h3": // unlikely in TLS, but handle anyway
return ProtoH3
}
}
// No ALPN or only "http/1.1": treat as HTTPS
return ProtoHTTPS
}
// relay copies data bidirectionally between client and remote until one
// side closes or the context is cancelled.
func relay(ctx context.Context, client, remote net.Conn) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()
errCh := make(chan error, 2)
go func() {
_, err := io.Copy(remote, client)
cancel()
errCh <- err
}()
go func() {
_, err := io.Copy(client, remote)
cancel()
errCh <- err
}()
var firstErr error
for range 2 {
if err := <-errCh; err != nil && firstErr == nil {
if !isClosedErr(err) {
firstErr = err
}
}
}
return firstErr
}
// evaluateAction runs rule evaluation and resolves the effective action.
// Pass empty path for connection-level (TLS), non-empty for request-level (HTTP).
func (p *Proxy) evaluateAction(src netip.Addr, sni domain.Domain, dst netip.AddrPort, proto ProtoType, path string) Action {
return p.rules.Evaluate(src, sni, dst.Addr(), dst.Port(), proto, path)
}
// dialTCP dials the destination, blocking connections to loopback, link-local,
// multicast, and WG overlay network addresses.
func (p *Proxy) dialTCP(ctx context.Context, dst netip.AddrPort) (net.Conn, error) {
ip := dst.Addr().Unmap()
if err := p.validateDialTarget(ip); err != nil {
return nil, fmt.Errorf("dial %s: %w", dst, err)
}
return p.dialer.DialContext(ctx, "tcp", dst.String())
}
// validateDialTarget blocks destinations that should never be dialed by the proxy.
// Mirrors the route validation in systemops.validateRoute.
func (p *Proxy) validateDialTarget(addr netip.Addr) error {
switch {
case !addr.IsValid():
return fmt.Errorf("invalid address")
case addr.IsLoopback():
return fmt.Errorf("loopback address not allowed")
case addr.IsLinkLocalUnicast(), addr.IsLinkLocalMulticast(), addr.IsInterfaceLocalMulticast():
return fmt.Errorf("link-local address not allowed")
case addr.IsMulticast():
return fmt.Errorf("multicast address not allowed")
case p.wgNetwork.IsValid() && p.wgNetwork.Contains(addr):
return fmt.Errorf("overlay network address not allowed")
case p.localIPs != nil && p.localIPs.IsLocalIP(addr):
return fmt.Errorf("local address not allowed")
}
return nil
}
func isClosedErr(err error) bool {
if err == nil {
return false
}
return err == io.EOF ||
err == io.ErrClosedPipe ||
err == net.ErrClosed ||
err == context.Canceled
}

View File

@@ -201,18 +201,7 @@ Pop $0
Function .onInit
StrCpy $INSTDIR "${INSTALL_DIR}"
; Default autostart to enabled so silent installs (/S) match the interactive default
StrCpy $AutostartEnabled "1"
; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
; in the 32-bit view. Fall back to it so upgrades still find them.
SetRegView 64
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
${If} $R0 == ""
SetRegView 32
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
SetRegView 64
${EndIf}
${If} $R0 != ""
# if silent install jump to uninstall step
IfSilent uninstall
@@ -225,10 +214,6 @@ ${If} $R0 != ""
${EndIf}
FunctionEnd
Function un.onInit
SetRegView 64
FunctionEnd
######################################################################
Section -MainProgram
${INSTALL_TYPE}
@@ -243,7 +228,6 @@ Section -MainProgram
!else
File /r "..\\dist\\netbird_windows_amd64\\"
!endif
File "..\\client\\ui\\assets\\netbird.png"
SectionEnd
######################################################################
@@ -263,11 +247,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
; Create autostart registry entry based on checkbox
DetailPrint "Autostart enabled: $AutostartEnabled"
${If} $AutostartEnabled == "1"
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
${Else}
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
DetailPrint "Autostart not enabled by user"
${EndIf}
@@ -301,8 +283,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
; Remove autostart registry entry
DetailPrint "Removing autostart registry entry if exists..."
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
; Handle data deletion based on checkbox
@@ -341,7 +321,6 @@ DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM

View File

@@ -94,7 +94,6 @@ func (c *ConnectClient) RunOnAndroid(
dnsAddresses []netip.AddrPort,
dnsReadyListener dns.ReadyListener,
stateFilePath string,
cacheDir string,
) error {
// in case of non Android os these variables will be nil
mobileDependency := MobileDependency{
@@ -104,7 +103,6 @@ func (c *ConnectClient) RunOnAndroid(
HostDNSAddresses: dnsAddresses,
DnsReadyListener: dnsReadyListener,
StateFilePath: stateFilePath,
TempDir: cacheDir,
}
return c.run(mobileDependency, nil, "")
}
@@ -333,10 +331,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
@@ -344,7 +338,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
log.Error(err)
return wrapErr(err)
}
engineConfig.TempDir = mobileDependency.TempDir
relayManager := relayClient.NewManager(engineCtx, relayURLs, myPrivateKey.PublicKey().String(), engineConfig.MTU)
c.statusRecorder.SetRelayMgr(relayManager)
@@ -569,6 +562,9 @@ func createEngineConfig(key wgtypes.Key, config *profilemanager.Config, peerConf
MTU: selectMTU(config.MTU, peerConfig.Mtu),
LogPath: logPath,
InspectionCACertPath: config.InspectionCACertPath,
InspectionCAKeyPath: config.InspectionCAKeyPath,
ProfileConfig: config,
}

View File

@@ -16,6 +16,7 @@ import (
"path/filepath"
"runtime"
"runtime/pprof"
"slices"
"sort"
"strings"
"time"
@@ -30,6 +31,7 @@ import (
"github.com/netbirdio/netbird/client/internal/updater/installer"
nbstatus "github.com/netbirdio/netbird/client/status"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
"github.com/netbirdio/netbird/util"
)
const readmeContent = `Netbird debug bundle
@@ -61,7 +63,6 @@ allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.
Anonymization Process
@@ -233,9 +234,7 @@ type BundleGenerator struct {
statusRecorder *peer.Status
syncResponse *mgmProto.SyncResponse
logPath string
tempDir string
cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter
@@ -257,10 +256,8 @@ type GeneratorDependencies struct {
StatusRecorder *peer.Status
SyncResponse *mgmProto.SyncResponse
LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte
CapturePath string
RefreshStatus func()
RefreshStatus func() // Optional callback to refresh status before bundle generation
ClientMetrics MetricsExporter
}
@@ -278,9 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
statusRecorder: deps.StatusRecorder,
syncResponse: deps.SyncResponse,
logPath: deps.LogPath,
tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics,
@@ -292,7 +287,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
// Generate creates a debug bundle and returns the location.
func (g *BundleGenerator) Generate() (resp string, err error) {
bundlePath, err := os.CreateTemp(g.tempDir, "netbird.debug.*.zip")
bundlePath, err := os.CreateTemp("", "netbird.debug.*.zip")
if err != nil {
return "", fmt.Errorf("create zip file: %w", err)
}
@@ -350,10 +345,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
}
if err := g.addCaptureFile(); err != nil {
log.Errorf("failed to add capture file to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err)
}
@@ -382,8 +373,15 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add wg show output: %v", err)
}
if err := g.addPlatformLog(); err != nil {
log.Errorf("failed to add logs to debug bundle: %v", err)
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs as fallback: %v", err)
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
log.Errorf("failed to add systemd logs: %v", err)
}
if err := g.addUpdateLogs(); err != nil {
@@ -607,12 +605,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
}
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -639,7 +631,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
}
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
}
func (g *BundleGenerator) addProf() (err error) {
@@ -684,29 +675,6 @@ func (g *BundleGenerator) addCPUProfile() error {
return nil
}
func (g *BundleGenerator) addCaptureFile() error {
if g.capturePath == "" {
return nil
}
if g.anonymize {
log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
return nil
}
f, err := os.Open(g.capturePath)
if err != nil {
return fmt.Errorf("open capture file: %w", err)
}
defer f.Close()
if err := g.addFileToZip(f, "capture.pcap"); err != nil {
return fmt.Errorf("add capture file to zip: %w", err)
}
return nil
}
func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true)

View File

@@ -1,41 +0,0 @@
//go:build android
package debug
import (
"fmt"
"io"
"os/exec"
log "github.com/sirupsen/logrus"
)
func (g *BundleGenerator) addPlatformLog() error {
cmd := exec.Command("/system/bin/logcat", "-d")
stdout, err := cmd.StdoutPipe()
if err != nil {
return fmt.Errorf("logcat stdout pipe: %w", err)
}
if err := cmd.Start(); err != nil {
return fmt.Errorf("start logcat: %w", err)
}
var logReader io.Reader = stdout
if g.anonymize {
var pw *io.PipeWriter
logReader, pw = io.Pipe()
go anonymizeLog(stdout, pw, g.anonymizer)
}
if err := g.addFileToZip(logReader, "logcat.txt"); err != nil {
return fmt.Errorf("add logcat to zip: %w", err)
}
if err := cmd.Wait(); err != nil {
return fmt.Errorf("wait logcat: %w", err)
}
log.Debug("added logcat output to debug bundle")
return nil
}

View File

@@ -1,25 +0,0 @@
//go:build !android
package debug
import (
"slices"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/util"
)
func (g *BundleGenerator) addPlatformLog() error {
if g.logPath != "" && !slices.Contains(util.SpecialLogs, g.logPath) {
if err := g.addLogfile(); err != nil {
log.Errorf("failed to add log file to debug bundle: %v", err)
if err := g.trySystemdLogFallback(); err != nil {
return err
}
}
} else if err := g.trySystemdLogFallback(); err != nil {
return err
}
return nil
}

View File

@@ -5,21 +5,16 @@ import (
"bytes"
"encoding/json"
"net"
"net/url"
"os"
"path/filepath"
"reflect"
"strings"
"testing"
"time"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
@@ -476,8 +471,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"HOME": "/root",
"PATH": "/usr/bin",
"HOME": "/root",
"PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug",
},
},
@@ -494,9 +489,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false,
input: map[string]any{
jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
"NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info",
},
},
check: func(t *testing.T, params map[string]any) {
@@ -771,127 +766,3 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
}
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -3,12 +3,10 @@ package debug
import (
"context"
"errors"
"net"
"net/http"
"os"
"path/filepath"
"testing"
"time"
"github.com/stretchr/testify/require"
@@ -21,10 +19,8 @@ func TestUpload(t *testing.T) {
t.Skip("Skipping upload test on docker ci")
}
testDir := t.TempDir()
addr := reserveLoopbackPort(t)
testURL := "http://" + addr
testURL := "http://localhost:8080"
t.Setenv("SERVER_URL", testURL)
t.Setenv("SERVER_ADDRESS", addr)
t.Setenv("STORE_DIR", testDir)
srv := server.NewServer()
go func() {
@@ -37,7 +33,6 @@ func TestUpload(t *testing.T) {
t.Errorf("Failed to stop server: %v", err)
}
})
waitForServer(t, addr)
file := filepath.Join(t.TempDir(), "tmpfile")
fileContent := []byte("test file content")
@@ -52,30 +47,3 @@ func TestUpload(t *testing.T) {
require.NoError(t, err)
require.Equal(t, fileContent, createdFileContent)
}
// reserveLoopbackPort binds an ephemeral port on loopback to learn a free
// address, then releases it so the server under test can rebind. The close/
// rebind window is racy in theory; on loopback with a kernel-assigned port
// it's essentially never contended in practice.
func reserveLoopbackPort(t *testing.T) string {
t.Helper()
l, err := net.Listen("tcp", "127.0.0.1:0")
require.NoError(t, err)
addr := l.Addr().String()
require.NoError(t, l.Close())
return addr
}
func waitForServer(t *testing.T, addr string) {
t.Helper()
deadline := time.Now().Add(5 * time.Second)
for time.Now().Before(deadline) {
c, err := net.DialTimeout("tcp", addr, 100*time.Millisecond)
if err == nil {
_ = c.Close()
return
}
time.Sleep(20 * time.Millisecond)
}
t.Fatalf("server did not start listening on %s in time", addr)
}

View File

@@ -13,7 +13,6 @@ import (
const (
defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
)
type resolvConf struct {

View File

@@ -1,10 +1,7 @@
package dns
import (
"context"
"fmt"
"math"
"net"
"slices"
"strconv"
"strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
}
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 {
return
}
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
// Try handlers in priority order
for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) {
continue
}
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime))
}
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch {
case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
}
}
}
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}

View File

@@ -1,15 +1,11 @@
package dns_test
import (
"context"
"net"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
})
}
}
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
}
func newHostManager(wgInterface string) (hostManager, error) {
osManager, reason, err := getOSDNSManagerType()
osManager, err := getOSDNSManagerType()
if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err)
}
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
log.Infof("System DNS manager discovered: %s", osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil {
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
}
}
func getOSDNSManagerType() (osManagerType, string, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
func getOSDNSManagerType() (osManagerType, error) {
file, err := os.Open(defaultResolvConfPath)
if err != nil {
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
}
defer func() {
if cerr := file.Close(); cerr != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
}
}()
var rejected []string
scanner := bufio.NewScanner(file)
for scanner.Scan() {
text := scanner.Text()
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
continue
}
if text[0] != '#' {
break
return fileManager, nil
}
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
return mgr, reason, nil, nil
} else if rej != "" {
rejected = append(rejected, rej)
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, nil
}
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
}
}
if err := scanner.Err(); err != nil && err != io.EOF {
return 0, "", nil, fmt.Errorf("scan: %w", err)
return 0, fmt.Errorf("scan: %w", err)
}
return 0, "", rejected, nil
return fileManager, nil
}
// matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
func checkStub() bool {
rConf, err := parseDefaultResolvConf()
if err != nil {
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
log.Warnf("failed to parse resolv conf: %s", err)
return true
}
@@ -178,36 +139,3 @@ func checkStub() bool {
return false
}
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -1,76 +0,0 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,83 +2,40 @@ package mgmt
import (
"context"
"errors"
"fmt"
"net"
"net/netip"
"net/url"
"os"
"slices"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/miekg/dns"
log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain"
)
const (
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
const dnsTimeout = 5 * time.Second
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
// Resolver caches critical NetBird infrastructure domains
type Resolver struct {
records map[dns.Question]*cachedRecord
records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex
}
chain ChainResolver
chainMaxPriority int
refreshGroup singleflight.Group
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
type ipsResponse struct {
ips []netip.Addr
err error
}
// NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver {
return &Resolver{
records: make(map[dns.Question]*cachedRecord),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
records: make(map[dns.Question][]dns.RR),
}
}
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver"
}
// SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned
// immediately and refreshed asynchronously (stale-while-revalidate).
// ServeDNS implements dns.Handler interface.
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 {
m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
}
m.mutex.RLock()
cached, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
records, found := m.records[question]
m.mutex.RUnlock()
if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return
}
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{}
resp.SetReply(r)
resp.Authoritative = false
resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
}
}
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
// A family that resolves NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
// AddDomain manually adds a domain to cache by resolving it.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel()
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
if errA != nil && errAAAA != nil {
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
return err
}
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
if err := errors.Join(errA, errAAAA); err != nil {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
var aRecords, aaaaRecords []dns.RR
for _, ip := range ips {
if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
}
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
}
now := time.Now()
m.mutex.Lock()
defer m.mutex.Unlock()
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
if len(aRecords) > 0 {
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords))
return nil
}
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
// untouched on error. Caller holds m.mutex.
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
resultChan := make(chan *ipsResponse, 1)
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
// unique in-flight key; bursty stale hits share its channel. expected is the
// cachedRecord pointer observed by the caller; the refresh only mutates the
// cache if that pointer is still the one stored, so a stale in-flight refresh
// can't clobber a newer entry written by AddDomain or a competing refresh.
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
go func() {
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
resultChan <- &ipsResponse{
err: err,
ips: ips,
}
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err
}()
var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
}
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
if len(records) == 0 {
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
if resp.err != nil {
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
}
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers tell records, NODATA (nil err, no records), and failure apart.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
return resp.ips, nil
}
// PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock()
defer m.mutex.Unlock()
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
delete(m.records, qA)
delete(m.records, qAAAA)
delete(m.refreshing, qA)
delete(m.refreshing, qAAAA)
aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString())
return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains
}
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
// A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -1,408 +0,0 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}

View File

@@ -6,7 +6,6 @@ import (
"net/url"
"strings"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains())
}
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct {
name string

View File

@@ -212,7 +212,6 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{
ctx: ctx,

View File

@@ -26,13 +26,12 @@ import (
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/udpmux"
"github.com/netbirdio/netbird/client/inspect"
"github.com/netbirdio/netbird/client/internal/acl"
"github.com/netbirdio/netbird/client/internal/debug"
"github.com/netbirdio/netbird/client/internal/dns"
@@ -69,7 +68,6 @@ import (
signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/capture"
)
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -139,11 +137,16 @@ type EngineConfig struct {
MTU uint16
// InspectionCACertPath is a local CA cert for transparent proxy MITM.
// Takes priority over management-pushed CA.
InspectionCACertPath string
// InspectionCAKeyPath is the corresponding private key.
InspectionCAKeyPath string
// for debug bundle generation
ProfileConfig *profilemanager.Config
LogPath string
TempDir string
}
// EngineServices holds the external service dependencies required by the Engine.
@@ -220,14 +223,16 @@ type Engine struct {
portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher
afpacketCapture *capture.AFPacketCapture
// Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex
persistSyncResponse bool
latestSyncResponse *mgmProto.SyncResponse
flowManager nftypes.FlowManager
// transparentProxy is the transparent forward proxy for traffic inspection.
transparentProxy *inspect.Proxy
udpInspectionHookID string
// auto-update
updateManager *updater.Manager
@@ -575,7 +580,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
e.connMgr.Start(e.ctx)
e.srWatcher = guard.NewSRWatcher(e.signal, e.relayManager, e.mobileDep.IFaceDiscover, iceCfg)
e.srWatcher.Start(peer.IsForceRelayed())
e.srWatcher.Start()
e.receiveSignalEvents()
e.receiveManagementEvents()
@@ -609,8 +614,6 @@ func (e *Engine) createFirewall() error {
return nil
}
firewalld.SetParentContext(e.ctx)
var err error
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes, e.config.MTU)
if err != nil {
@@ -948,12 +951,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err)
}
urls := update.Urls
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
e.relayManager.UpdateServerURLs(update.Urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries.
@@ -1108,7 +1106,6 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
StatusRecorder: e.statusRecorder,
SyncResponse: syncResponse,
LogPath: e.config.LogPath,
TempDir: e.config.TempDir,
ClientMetrics: e.clientMetrics,
RefreshStatus: func() {
e.RunHealthProbes(true)
@@ -1286,6 +1283,9 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
fwdEntries := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), routes)
e.updateDNSForwarder(dnsRouteFeatureFlag, fwdEntries)
// Transparent proxy
e.updateTransparentProxy(networkMap.GetTransparentProxyConfig())
// Ingress forward rules
forwardingRules, err := e.updateForwardRules(networkMap.GetForwardingRules())
if err != nil {
@@ -1707,13 +1707,10 @@ func (e *Engine) parseNATExternalIPMappings() []string {
}
func (e *Engine) close() {
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
e.stopTransparentProxy()
if e.wgInterface != nil {
if err := e.wgInterface.Close(); err != nil {
log.Errorf("failed closing Netbird interface %s %v", e.config.WgIfaceName, err)
@@ -2177,62 +2174,6 @@ func (e *Engine) Address() (netip.Addr, error) {
return e.wgInterface.Address().IP, nil
}
// SetCapture sets or clears packet capture on the WireGuard device.
// On userspace WireGuard, it taps the FilteredDevice directly.
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
// Pass nil to disable capture.
func (e *Engine) SetCapture(pc device.PacketCapture) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
intf := e.wgInterface
if intf == nil {
return errors.New("wireguard interface not initialized")
}
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
dev := intf.GetDevice()
if dev != nil {
dev.SetCapture(pc)
e.setForwarderCapture(pc)
return nil
}
// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
if pc == nil {
return nil
}
sess, ok := pc.(*capture.Session)
if !ok {
return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
}
afc := capture.NewAFPacketCapture(intf.Name(), sess)
if err := afc.Start(); err != nil {
return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
}
e.afpacketCapture = afc
return nil
}
// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
if e.firewall == nil {
return
}
type forwarderCapturer interface {
SetPacketCapture(pc forwarder.PacketCapture)
}
if fc, ok := e.firewall.(forwarderCapturer); ok {
fc.SetPacketCapture(pc)
}
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules")
@@ -2454,8 +2395,6 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
}
}
relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())
offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag,
@@ -2466,23 +2405,7 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
RelaySrvIP: relayIP,
SessionID: sessionID,
}
return &offerAnswer, nil
}
// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
// netip.Addr. Returns the zero value for empty input and logs a warning
// for malformed payloads.
func decodeRelayIP(b []byte) netip.Addr {
if len(b) == 0 {
return netip.Addr{}
}
ip, ok := netip.AddrFromSlice(b)
if !ok {
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
return netip.Addr{}
}
return ip.Unmap()
}

View File

@@ -55,7 +55,6 @@ import (
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/server"
"github.com/netbirdio/netbird/management/server/activity"
nbcache "github.com/netbirdio/netbird/management/server/cache"
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
"github.com/netbirdio/netbird/management/server/permissions"
"github.com/netbirdio/netbird/management/server/settings"
@@ -1635,12 +1634,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
peersManager := peers.NewManager(store, permissionsManager)
jobManager := job.NewJobManager(nil, store, peersManager)
cacheStore, err := nbcache.NewStore(context.Background(), 100*time.Millisecond, 300*time.Millisecond, 100)
if err != nil {
return nil, "", err
}
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore, cacheStore)
ia, _ := integrations.NewIntegratedValidator(context.Background(), peersManager, nil, eventStore)
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
require.NoError(t, err)
@@ -1662,7 +1656,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
updateManager := update_channel.NewPeersUpdateManager(metrics)
requestBuffer := server.NewAccountRequestBuffer(context.Background(), store)
networkMapController := controller.NewController(context.Background(), store, metrics, updateManager, requestBuffer, server.MockIntegratedValidator{}, settingsMockManager, "netbird.selfhosted", port_forwarding.NewControllerMock(), manager.NewEphemeralManager(store, peersManager), config)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false, cacheStore)
accountManager, err := server.BuildManager(context.Background(), config, store, networkMapController, jobManager, nil, "", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager, false)
if err != nil {
return nil, "", err
}
@@ -1671,7 +1665,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil {
return nil, "", err
}
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil {
return nil, "", err
}

View File

@@ -0,0 +1,571 @@
package internal
import (
"crypto/ecdsa"
"crypto/elliptic"
"crypto/rand"
"crypto/x509"
"crypto/x509/pkix"
"encoding/pem"
"fmt"
"math/big"
"net/netip"
"net/url"
"os"
"strconv"
"time"
log "github.com/sirupsen/logrus"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/inspect"
"github.com/netbirdio/netbird/client/internal/acl/id"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
// updateTransparentProxy processes transparent proxy configuration from the network map.
func (e *Engine) updateTransparentProxy(cfg *mgmProto.TransparentProxyConfig) {
if cfg == nil || !cfg.Enabled {
if cfg == nil {
log.Tracef("inspect: config is nil")
} else {
log.Tracef("inspect: config disabled")
}
// Only stop if explicitly disabled. Don't stop on nil config to avoid
// a gap during policy edits where management briefly pushes empty config.
if cfg != nil && !cfg.Enabled {
e.stopTransparentProxy()
}
return
}
log.Debugf("inspect: config received: enabled=%v mode=%v default_action=%v rules=%d has_ca=%v",
cfg.Enabled, cfg.Mode, cfg.DefaultAction, len(cfg.Rules), len(cfg.CaCertPem) > 0)
// BlockInbound prevents adding TPROXY rules since kernel TPROXY bypasses ACLs.
// The userspace forwarder path still works as it operates within the forwarder hook.
if e.config.BlockInbound {
log.Warnf("inspect: BlockInbound is set, skipping redirect rules (userspace path still active)")
}
proxyConfig, err := toProxyConfig(cfg)
if err != nil {
log.Errorf("inspect: parse config: %v", err)
e.stopTransparentProxy()
return
}
// CA priority: local config > management-pushed > auto-generated self-signed.
// Local wins over mgmt to prevent compromised management from injecting a CA.
e.resolveInspectionCA(&proxyConfig)
if e.transparentProxy != nil {
// Mode change requires full recreate (envoy lifecycle, listener changes).
if proxyConfig.Mode != e.transparentProxy.Mode() {
log.Infof("inspect: mode changed to %s, recreating engine", proxyConfig.Mode)
e.stopTransparentProxy()
} else {
e.transparentProxy.UpdateConfig(proxyConfig)
e.syncTProxyRules(proxyConfig)
e.syncUDPInspectionHook()
return
}
}
if e.wgInterface != nil {
proxyConfig.WGNetwork = e.wgInterface.Address().Network
proxyConfig.ListenAddr = netip.AddrPortFrom(
e.wgInterface.Address().IP.Unmap(),
proxyConfig.ListenAddr.Port(),
)
}
// Pass local IP checker for SSRF prevention
if checker, ok := e.firewall.(inspect.LocalIPChecker); ok {
proxyConfig.LocalIPChecker = checker
}
p, err := inspect.New(e.ctx, log.WithField("component", "inspect"), proxyConfig)
if err != nil {
log.Errorf("inspect: start engine: %v", err)
return
}
e.transparentProxy = p
e.attachProxyToForwarder(p)
e.syncTProxyRules(proxyConfig)
e.syncUDPInspectionHook()
log.Infof("inspect: engine started (mode=%s, rules=%d)", proxyConfig.Mode, len(proxyConfig.Rules))
}
// stopTransparentProxy shuts down the transparent proxy and removes interception.
func (e *Engine) stopTransparentProxy() {
if e.transparentProxy == nil {
return
}
e.attachProxyToForwarder(nil)
e.removeTProxyRule()
e.removeUDPInspectionHook()
if err := e.transparentProxy.Close(); err != nil {
log.Debugf("inspect: close engine: %v", err)
}
e.transparentProxy = nil
log.Info("inspect: engine stopped")
}
const tproxyRuleID = "tproxy-redirect"
// syncTProxyRules adds a TPROXY rule via the firewall manager to intercept
// matching traffic on the WG interface and redirect it to the proxy socket.
func (e *Engine) syncTProxyRules(config inspect.Config) {
if e.config.BlockInbound {
e.removeTProxyRule()
return
}
var listenPort uint16
if e.transparentProxy != nil {
listenPort = e.transparentProxy.ListenPort()
}
if listenPort == 0 {
e.removeTProxyRule()
return
}
if e.firewall == nil {
return
}
dstPorts := make([]uint16, len(config.RedirectPorts))
copy(dstPorts, config.RedirectPorts)
log.Debugf("inspect: syncing redirect rules: listen port %d, redirect ports %v, sources %v",
listenPort, dstPorts, config.RedirectSources)
if err := e.firewall.AddTProxyRule(tproxyRuleID, config.RedirectSources, dstPorts, listenPort); err != nil {
log.Errorf("inspect: add redirect rule: %v", err)
return
}
}
// removeTProxyRule removes the TPROXY redirect rule.
func (e *Engine) removeTProxyRule() {
if e.firewall == nil {
return
}
if err := e.firewall.RemoveTProxyRule(tproxyRuleID); err != nil {
log.Debugf("inspect: remove redirect rule: %v", err)
}
}
// syncUDPInspectionHook registers a UDP packet hook on port 443 for QUIC SNI blocking.
// The hook is called by the USP filter for each UDP packet matching the port,
// allowing the inspection engine to extract QUIC SNI and block by domain.
func (e *Engine) syncUDPInspectionHook() {
e.removeUDPInspectionHook()
if e.firewall == nil || e.transparentProxy == nil {
return
}
p := e.transparentProxy
hookID := e.firewall.AddUDPInspectionHook(443, func(packet []byte) bool {
srcIP, dstIP, dstPort, udpPayload, ok := parseUDPPacket(packet)
if !ok {
return false
}
src := inspect.SourceInfo{IP: srcIP}
dst := netip.AddrPortFrom(dstIP, dstPort)
action := p.HandleUDPPacket(udpPayload, dst, src)
return action == inspect.ActionBlock
})
e.udpInspectionHookID = hookID
log.Debugf("inspect: registered UDP inspection hook on port 443 (id=%s)", hookID)
}
// removeUDPInspectionHook removes the QUIC inspection hook.
func (e *Engine) removeUDPInspectionHook() {
if e.udpInspectionHookID == "" || e.firewall == nil {
return
}
e.firewall.RemoveUDPInspectionHook(e.udpInspectionHookID)
e.udpInspectionHookID = ""
}
// parseUDPPacket extracts source/destination IP, destination port, and UDP
// payload from a raw IP packet. Supports both IPv4 and IPv6.
func parseUDPPacket(packet []byte) (srcIP, dstIP netip.Addr, dstPort uint16, payload []byte, ok bool) {
if len(packet) < 1 {
return srcIP, dstIP, 0, nil, false
}
version := packet[0] >> 4
var udpOffset int
switch version {
case 4:
if len(packet) < 20 {
return srcIP, dstIP, 0, nil, false
}
ihl := int(packet[0]&0x0f) * 4
if len(packet) < ihl+8 {
return srcIP, dstIP, 0, nil, false
}
var srcOK, dstOK bool
srcIP, srcOK = netip.AddrFromSlice(packet[12:16])
dstIP, dstOK = netip.AddrFromSlice(packet[16:20])
if !srcOK || !dstOK {
return srcIP, dstIP, 0, nil, false
}
udpOffset = ihl
case 6:
// IPv6 fixed header is 40 bytes. Next header must be UDP (17).
if len(packet) < 48 { // 40 header + 8 UDP
return srcIP, dstIP, 0, nil, false
}
nextHeader := packet[6]
if nextHeader != 17 { // not UDP (may have extension headers)
return srcIP, dstIP, 0, nil, false
}
var srcOK, dstOK bool
srcIP, srcOK = netip.AddrFromSlice(packet[8:24])
dstIP, dstOK = netip.AddrFromSlice(packet[24:40])
if !srcOK || !dstOK {
return srcIP, dstIP, 0, nil, false
}
udpOffset = 40
default:
return srcIP, dstIP, 0, nil, false
}
srcIP = srcIP.Unmap()
dstIP = dstIP.Unmap()
dstPort = uint16(packet[udpOffset+2])<<8 | uint16(packet[udpOffset+3])
payload = packet[udpOffset+8:]
return srcIP, dstIP, dstPort, payload, true
}
// attachProxyToForwarder sets or clears the proxy on the userspace forwarder.
func (e *Engine) attachProxyToForwarder(p *inspect.Proxy) {
type forwarderGetter interface {
GetForwarder() *forwarder.Forwarder
}
if fg, ok := e.firewall.(forwarderGetter); ok {
if fwd := fg.GetForwarder(); fwd != nil {
fwd.SetProxy(p)
}
}
}
// toProxyConfig converts a proto TransparentProxyConfig to the inspect.Config type.
func toProxyConfig(cfg *mgmProto.TransparentProxyConfig) (inspect.Config, error) {
config := inspect.Config{
Enabled: cfg.Enabled,
DefaultAction: toProxyAction(cfg.DefaultAction),
}
switch cfg.Mode {
case mgmProto.TransparentProxyMode_TP_MODE_ENVOY:
config.Mode = inspect.ModeEnvoy
case mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL:
config.Mode = inspect.ModeExternal
default:
config.Mode = inspect.ModeBuiltin
}
if cfg.ExternalProxyUrl != "" {
u, err := url.Parse(cfg.ExternalProxyUrl)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse external proxy URL: %w", err)
}
config.ExternalURL = u
}
for _, s := range cfg.RedirectSources {
prefix, err := netip.ParsePrefix(s)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse redirect source %q: %w", s, err)
}
config.RedirectSources = append(config.RedirectSources, prefix)
}
for _, p := range cfg.RedirectPorts {
config.RedirectPorts = append(config.RedirectPorts, uint16(p))
}
// TPROXY listen port: fixed default, overridable via env var.
if config.Mode == inspect.ModeBuiltin {
port := uint16(inspect.DefaultTProxyPort)
if v := os.Getenv("NB_TPROXY_PORT"); v != "" {
if p, err := strconv.ParseUint(v, 10, 16); err == nil {
port = uint16(p)
} else {
log.Warnf("invalid NB_TPROXY_PORT %q, using default %d", v, inspect.DefaultTProxyPort)
}
}
config.ListenAddr = netip.AddrPortFrom(netip.IPv4Unspecified(), port)
}
for _, r := range cfg.Rules {
rule, err := toProxyRule(r)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse rule %q: %w", r.Id, err)
}
config.Rules = append(config.Rules, rule)
}
if cfg.Icap != nil {
icapCfg, err := toICAPConfig(cfg.Icap)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse ICAP config: %w", err)
}
config.ICAP = icapCfg
}
if len(cfg.CaCertPem) > 0 && len(cfg.CaKeyPem) > 0 {
tlsCfg, err := parseTLSConfig(cfg.CaCertPem, cfg.CaKeyPem)
if err != nil {
return inspect.Config{}, fmt.Errorf("parse TLS config: %w", err)
}
config.TLS = tlsCfg
}
if config.Mode == inspect.ModeEnvoy {
envCfg := &inspect.EnvoyConfig{
BinaryPath: cfg.EnvoyBinaryPath,
AdminPort: uint16(cfg.EnvoyAdminPort),
}
if cfg.EnvoySnippets != nil {
envCfg.Snippets = &inspect.EnvoySnippets{
HTTPFilters: cfg.EnvoySnippets.HttpFilters,
NetworkFilters: cfg.EnvoySnippets.NetworkFilters,
Clusters: cfg.EnvoySnippets.Clusters,
}
}
config.Envoy = envCfg
}
return config, nil
}
func toProxyRule(r *mgmProto.TransparentProxyRule) (inspect.Rule, error) {
rule := inspect.Rule{
ID: id.RuleID(r.Id),
Action: toProxyAction(r.Action),
Priority: int(r.Priority),
}
for _, d := range r.Domains {
dom, err := domain.FromString(d)
if err != nil {
return inspect.Rule{}, fmt.Errorf("parse domain %q: %w", d, err)
}
rule.Domains = append(rule.Domains, dom)
}
for _, n := range r.Networks {
prefix, err := netip.ParsePrefix(n)
if err != nil {
return inspect.Rule{}, fmt.Errorf("parse network %q: %w", n, err)
}
rule.Networks = append(rule.Networks, prefix)
}
for _, p := range r.Ports {
rule.Ports = append(rule.Ports, uint16(p))
}
for _, proto := range r.Protocols {
rule.Protocols = append(rule.Protocols, toProxyProtoType(proto))
}
rule.Paths = r.Paths
return rule, nil
}
func toProxyProtoType(p mgmProto.TransparentProxyProtocol) inspect.ProtoType {
switch p {
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTP:
return inspect.ProtoHTTP
case mgmProto.TransparentProxyProtocol_TP_PROTO_HTTPS:
return inspect.ProtoHTTPS
case mgmProto.TransparentProxyProtocol_TP_PROTO_H2:
return inspect.ProtoH2
case mgmProto.TransparentProxyProtocol_TP_PROTO_H3:
return inspect.ProtoH3
case mgmProto.TransparentProxyProtocol_TP_PROTO_WEBSOCKET:
return inspect.ProtoWebSocket
case mgmProto.TransparentProxyProtocol_TP_PROTO_OTHER:
return inspect.ProtoOther
default:
return ""
}
}
func toProxyAction(a mgmProto.TransparentProxyAction) inspect.Action {
switch a {
case mgmProto.TransparentProxyAction_TP_ACTION_BLOCK:
return inspect.ActionBlock
case mgmProto.TransparentProxyAction_TP_ACTION_INSPECT:
return inspect.ActionInspect
default:
return inspect.ActionAllow
}
}
func toICAPConfig(cfg *mgmProto.TransparentProxyICAPConfig) (*inspect.ICAPConfig, error) {
icap := &inspect.ICAPConfig{
MaxConnections: int(cfg.MaxConnections),
}
if cfg.ReqmodUrl != "" {
u, err := url.Parse(cfg.ReqmodUrl)
if err != nil {
return nil, fmt.Errorf("parse ICAP reqmod URL: %w", err)
}
icap.ReqModURL = u
}
if cfg.RespmodUrl != "" {
u, err := url.Parse(cfg.RespmodUrl)
if err != nil {
return nil, fmt.Errorf("parse ICAP respmod URL: %w", err)
}
icap.RespModURL = u
}
return icap, nil
}
func parseTLSConfig(certPEM, keyPEM []byte) (*inspect.TLSConfig, error) {
block, _ := pem.Decode(certPEM)
if block == nil {
return nil, fmt.Errorf("decode CA certificate PEM")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return nil, fmt.Errorf("parse CA certificate: %w", err)
}
keyBlock, _ := pem.Decode(keyPEM)
if keyBlock == nil {
return nil, fmt.Errorf("decode CA key PEM")
}
key, err := x509.ParseECPrivateKey(keyBlock.Bytes)
if err != nil {
// Try PKCS8 as fallback
pkcs8Key, pkcs8Err := x509.ParsePKCS8PrivateKey(keyBlock.Bytes)
if pkcs8Err != nil {
return nil, fmt.Errorf("parse CA private key (tried EC and PKCS8): %w", err)
}
return &inspect.TLSConfig{CA: cert, CAKey: pkcs8Key}, nil
}
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
}
// resolveInspectionCA sets the TLS config on the proxy config using priority:
// 1. Local config file CA (InspectionCACertPath/InspectionCAKeyPath)
// 2. Management-pushed CA (already parsed in toProxyConfig)
// 3. Auto-generated self-signed CA (ephemeral, for testing)
// Local always wins to prevent a compromised management server from injecting a CA.
func (e *Engine) resolveInspectionCA(config *inspect.Config) {
// 1. Local CA from config file or env vars
certPath := e.config.InspectionCACertPath
keyPath := e.config.InspectionCAKeyPath
if certPath == "" {
certPath = os.Getenv("NB_INSPECTION_CA_CERT")
}
if keyPath == "" {
keyPath = os.Getenv("NB_INSPECTION_CA_KEY")
}
if certPath != "" && keyPath != "" {
certPEM, err := os.ReadFile(certPath)
if err != nil {
log.Errorf("read local inspection CA cert %s: %v", certPath, err)
return
}
keyPEM, err := os.ReadFile(keyPath)
if err != nil {
log.Errorf("read local inspection CA key %s: %v", keyPath, err)
return
}
tlsCfg, err := parseTLSConfig(certPEM, keyPEM)
if err != nil {
log.Errorf("parse local inspection CA: %v", err)
return
}
log.Infof("inspect: using local CA from %s", certPath)
config.TLS = tlsCfg
return
}
// 2. Management-pushed CA (already set by toProxyConfig)
if config.TLS != nil {
log.Infof("inspect: using management-pushed CA")
return
}
// 3. Auto-generate self-signed CA for testing / accept-cert UX
tlsCfg, err := generateSelfSignedCA()
if err != nil {
log.Errorf("generate self-signed inspection CA: %v", err)
return
}
log.Infof("inspect: using auto-generated self-signed CA (clients will see certificate warnings)")
config.TLS = tlsCfg
}
// generateSelfSignedCA creates an ephemeral ECDSA P-256 CA certificate.
// Clients will see certificate warnings but can choose to accept.
func generateSelfSignedCA() (*inspect.TLSConfig, error) {
key, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
if err != nil {
return nil, fmt.Errorf("generate CA key: %w", err)
}
serial, err := rand.Int(rand.Reader, new(big.Int).Lsh(big.NewInt(1), 128))
if err != nil {
return nil, fmt.Errorf("generate serial: %w", err)
}
template := &x509.Certificate{
SerialNumber: serial,
Subject: pkix.Name{
Organization: []string{"NetBird Transparent Proxy"},
CommonName: "NetBird Inspection CA (auto-generated)",
},
NotBefore: time.Now().Add(-1 * time.Hour),
NotAfter: time.Now().Add(365 * 24 * time.Hour),
KeyUsage: x509.KeyUsageCertSign | x509.KeyUsageCRLSign,
BasicConstraintsValid: true,
IsCA: true,
MaxPathLen: 0,
}
certDER, err := x509.CreateCertificate(rand.Reader, template, template, &key.PublicKey, key)
if err != nil {
return nil, fmt.Errorf("create CA certificate: %w", err)
}
cert, err := x509.ParseCertificate(certDER)
if err != nil {
return nil, fmt.Errorf("parse generated CA certificate: %w", err)
}
return &inspect.TLSConfig{CA: cert, CAKey: key}, nil
}

View File

@@ -0,0 +1,279 @@
package internal
import (
"testing"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/inspect"
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
)
func TestToProxyConfig_Basic(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Mode: mgmProto.TransparentProxyMode_TP_MODE_BUILTIN,
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_ALLOW,
RedirectSources: []string{
"10.0.0.0/24",
"192.168.1.0/24",
},
RedirectPorts: []uint32{80, 443},
Rules: []*mgmProto.TransparentProxyRule{
{
Id: "block-evil",
Domains: []string{"*.evil.com", "malware.example.com"},
Action: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
Priority: 1,
},
{
Id: "inspect-internal",
Domains: []string{"*.internal.corp"},
Networks: []string{"10.1.0.0/16"},
Ports: []uint32{443, 8443},
Action: mgmProto.TransparentProxyAction_TP_ACTION_INSPECT,
Priority: 10,
},
},
ListenPort: 8443,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.True(t, config.Enabled)
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
require.Len(t, config.RedirectSources, 2)
assert.Equal(t, "10.0.0.0/24", config.RedirectSources[0].String())
assert.Equal(t, "192.168.1.0/24", config.RedirectSources[1].String())
require.Len(t, config.RedirectPorts, 2)
assert.Equal(t, uint16(80), config.RedirectPorts[0])
assert.Equal(t, uint16(443), config.RedirectPorts[1])
require.Len(t, config.Rules, 2)
// Rule 1: block evil domains
assert.Equal(t, "block-evil", string(config.Rules[0].ID))
assert.Equal(t, inspect.ActionBlock, config.Rules[0].Action)
assert.Equal(t, 1, config.Rules[0].Priority)
require.Len(t, config.Rules[0].Domains, 2)
assert.Equal(t, "*.evil.com", config.Rules[0].Domains[0].PunycodeString())
assert.Equal(t, "malware.example.com", config.Rules[0].Domains[1].PunycodeString())
// Rule 2: inspect internal
assert.Equal(t, "inspect-internal", string(config.Rules[1].ID))
assert.Equal(t, inspect.ActionInspect, config.Rules[1].Action)
assert.Equal(t, 10, config.Rules[1].Priority)
require.Len(t, config.Rules[1].Networks, 1)
assert.Equal(t, "10.1.0.0/16", config.Rules[1].Networks[0].String())
require.Len(t, config.Rules[1].Ports, 2)
// Listen address
assert.True(t, config.ListenAddr.IsValid())
assert.Equal(t, uint16(8443), config.ListenAddr.Port())
}
func TestToProxyConfig_ExternalMode(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Mode: mgmProto.TransparentProxyMode_TP_MODE_EXTERNAL,
ExternalProxyUrl: "http://proxy.corp:8080",
DefaultAction: mgmProto.TransparentProxyAction_TP_ACTION_BLOCK,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.Equal(t, inspect.ModeExternal, config.Mode)
assert.Equal(t, inspect.ActionBlock, config.DefaultAction)
require.NotNil(t, config.ExternalURL)
assert.Equal(t, "http", config.ExternalURL.Scheme)
assert.Equal(t, "proxy.corp:8080", config.ExternalURL.Host)
}
func TestToProxyConfig_ICAP(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Icap: &mgmProto.TransparentProxyICAPConfig{
ReqmodUrl: "icap://icap-server:1344/reqmod",
RespmodUrl: "icap://icap-server:1344/respmod",
MaxConnections: 16,
},
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
require.NotNil(t, config.ICAP)
assert.Equal(t, "icap", config.ICAP.ReqModURL.Scheme)
assert.Equal(t, "icap-server:1344", config.ICAP.ReqModURL.Host)
assert.Equal(t, "/reqmod", config.ICAP.ReqModURL.Path)
assert.Equal(t, "/respmod", config.ICAP.RespModURL.Path)
assert.Equal(t, 16, config.ICAP.MaxConnections)
}
func TestToProxyConfig_Empty(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
}
config, err := toProxyConfig(cfg)
require.NoError(t, err)
assert.True(t, config.Enabled)
assert.Equal(t, inspect.ModeBuiltin, config.Mode)
assert.Equal(t, inspect.ActionAllow, config.DefaultAction)
assert.Empty(t, config.RedirectSources)
assert.Empty(t, config.RedirectPorts)
assert.Empty(t, config.Rules)
assert.Nil(t, config.ICAP)
assert.Nil(t, config.TLS)
assert.False(t, config.ListenAddr.IsValid())
}
func TestToProxyConfig_InvalidSource(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
RedirectSources: []string{"not-a-cidr"},
}
_, err := toProxyConfig(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "parse redirect source")
}
func TestToProxyConfig_InvalidNetwork(t *testing.T) {
cfg := &mgmProto.TransparentProxyConfig{
Enabled: true,
Rules: []*mgmProto.TransparentProxyRule{
{
Id: "bad",
Networks: []string{"not-a-cidr"},
},
},
}
_, err := toProxyConfig(cfg)
require.Error(t, err)
assert.Contains(t, err.Error(), "parse network")
}
func TestToProxyAction(t *testing.T) {
assert.Equal(t, inspect.ActionAllow, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_ALLOW))
assert.Equal(t, inspect.ActionBlock, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_BLOCK))
assert.Equal(t, inspect.ActionInspect, toProxyAction(mgmProto.TransparentProxyAction_TP_ACTION_INSPECT))
// Unknown defaults to allow
assert.Equal(t, inspect.ActionAllow, toProxyAction(99))
}
func TestParseUDPPacket_IPv4(t *testing.T) {
// Build a minimal IPv4/UDP packet: 20-byte IPv4 header + 8-byte UDP header + payload
packet := make([]byte, 20+8+4)
// IPv4 header: version=4, IHL=5 (20 bytes)
packet[0] = 0x45
// Protocol = UDP (17)
packet[9] = 17
// Source IP: 10.0.0.1
packet[12], packet[13], packet[14], packet[15] = 10, 0, 0, 1
// Dest IP: 192.168.1.1
packet[16], packet[17], packet[18], packet[19] = 192, 168, 1, 1
// UDP source port: 54321 (0xD431)
packet[20] = 0xD4
packet[21] = 0x31
// UDP dest port: 443 (0x01BB)
packet[22] = 0x01
packet[23] = 0xBB
// Payload
packet[28] = 0xDE
packet[29] = 0xAD
packet[30] = 0xBE
packet[31] = 0xEF
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.Equal(t, "10.0.0.1", srcIP.String())
assert.Equal(t, "192.168.1.1", dstIP.String())
assert.Equal(t, uint16(443), dstPort)
assert.Equal(t, []byte{0xDE, 0xAD, 0xBE, 0xEF}, payload)
}
func TestParseUDPPacket_IPv6(t *testing.T) {
// Build a minimal IPv6/UDP packet: 40-byte IPv6 header + 8-byte UDP header + payload
packet := make([]byte, 40+8+4)
// Version = 6 (0x60 in high nibble)
packet[0] = 0x60
// Payload length: 8 (UDP header) + 4 (payload)
packet[4] = 0
packet[5] = 12
// Next header: UDP (17)
packet[6] = 17
// Source: 2001:db8::1
packet[8] = 0x20
packet[9] = 0x01
packet[10] = 0x0d
packet[11] = 0xb8
packet[23] = 0x01
// Dest: 2001:db8::2
packet[24] = 0x20
packet[25] = 0x01
packet[26] = 0x0d
packet[27] = 0xb8
packet[39] = 0x02
// UDP source port: 54321 (0xD431)
packet[40] = 0xD4
packet[41] = 0x31
// UDP dest port: 443 (0x01BB)
packet[42] = 0x01
packet[43] = 0xBB
// Payload
packet[48] = 0xCA
packet[49] = 0xFE
packet[50] = 0xBA
packet[51] = 0xBE
srcIP, dstIP, dstPort, payload, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.Equal(t, "2001:db8::1", srcIP.String())
assert.Equal(t, "2001:db8::2", dstIP.String())
assert.Equal(t, uint16(443), dstPort)
assert.Equal(t, []byte{0xCA, 0xFE, 0xBA, 0xBE}, payload)
}
func TestParseUDPPacket_TooShort(t *testing.T) {
_, _, _, _, ok := parseUDPPacket(nil)
assert.False(t, ok)
_, _, _, _, ok = parseUDPPacket([]byte{0x45, 0x00})
assert.False(t, ok)
}
func TestParseUDPPacket_IPv6ExtensionHeader(t *testing.T) {
// IPv6 with next header != UDP should be rejected
packet := make([]byte, 48)
packet[0] = 0x60
packet[6] = 6 // TCP, not UDP
_, _, _, _, ok := parseUDPPacket(packet)
assert.False(t, ok, "should reject IPv6 packets with non-UDP next header")
}
func TestParseUDPPacket_IPv4MappedIPv6(t *testing.T) {
// IPv4 packet with normal addresses should Unmap correctly
packet := make([]byte, 28)
packet[0] = 0x45
packet[9] = 17
packet[12], packet[13], packet[14], packet[15] = 127, 0, 0, 1
packet[16], packet[17], packet[18], packet[19] = 10, 0, 0, 1
packet[22] = 0x01
packet[23] = 0xBB
srcIP, dstIP, _, _, ok := parseUDPPacket(packet)
require.True(t, ok)
assert.True(t, srcIP.Is4(), "should be plain IPv4, not mapped")
assert.True(t, dstIP.Is4(), "should be plain IPv4, not mapped")
}

View File

@@ -3,6 +3,7 @@ package activity
import (
"net"
"net/netip"
"runtime"
"testing"
"time"
@@ -17,6 +18,10 @@ import (
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
)
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn
@@ -176,6 +181,10 @@ func TestBindListener_Close(t *testing.T) {
}
func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -217,6 +226,10 @@ func TestManager_BindMode(t *testing.T) {
}
func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}

View File

@@ -4,12 +4,14 @@ import (
"errors"
"net"
"net/netip"
"runtime"
"sync"
"time"
log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -73,6 +75,16 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg)
}
// BindListener is used on Windows, JS, and netstack platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider)
if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")

View File

@@ -6,6 +6,7 @@ import (
"time"
log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
@@ -90,8 +91,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
m.routesMu.Lock()
defer m.routesMu.Unlock()
clear(m.peerToHAGroups)
clear(m.haGroupToPeers)
maps.Clear(m.peerToHAGroups)
maps.Clear(m.haGroupToPeers)
for haUniqueID, routes := range haMap {
var peers []string

Some files were not shown because too many files have changed in this diff Show More