mirror of
https://github.com/netbirdio/netbird.git
synced 2026-05-06 00:56:39 +00:00
Compare commits
2 Commits
fix/debug-
...
revert-dns
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
faf3559ea7 | ||
|
|
dc8c2edf50 |
307
.github/workflows/release.yml
vendored
307
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.4"
|
SIGN_PIPE_VER: "v0.1.2"
|
||||||
GORELEASER_VER: "v2.14.3"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -114,13 +114,7 @@ jobs:
|
|||||||
retention-days: 30
|
retention-days: 30
|
||||||
|
|
||||||
release:
|
release:
|
||||||
runs-on: ubuntu-24.04-8-core
|
runs-on: ubuntu-latest-m
|
||||||
outputs:
|
|
||||||
release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
|
|
||||||
linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
|
|
||||||
windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
|
|
||||||
macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
|
|
||||||
ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
|
|
||||||
env:
|
env:
|
||||||
flags: ""
|
flags: ""
|
||||||
steps:
|
steps:
|
||||||
@@ -219,13 +213,10 @@ jobs:
|
|||||||
if: always()
|
if: always()
|
||||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
- name: Tag and push images (amd64 only)
|
- name: Tag and push images (amd64 only)
|
||||||
id: tag_and_push_images
|
|
||||||
if: |
|
if: |
|
||||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||||
run: |
|
run: |
|
||||||
set -euo pipefail
|
|
||||||
|
|
||||||
resolve_tags() {
|
resolve_tags() {
|
||||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||||
echo "pr-${{ github.event.pull_request.number }}"
|
echo "pr-${{ github.event.pull_request.number }}"
|
||||||
@@ -234,17 +225,6 @@ jobs:
|
|||||||
fi
|
fi
|
||||||
}
|
}
|
||||||
|
|
||||||
ghcr_package_url() {
|
|
||||||
local image="$1" package encoded_package
|
|
||||||
package="${image#ghcr.io/}"
|
|
||||||
package="${package#*/}"
|
|
||||||
package="${package%%:*}"
|
|
||||||
encoded_package="${package//\//%2F}"
|
|
||||||
echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
|
|
||||||
}
|
|
||||||
|
|
||||||
image_refs=()
|
|
||||||
|
|
||||||
tag_and_push() {
|
tag_and_push() {
|
||||||
local src="$1" img_name tag dst
|
local src="$1" img_name tag dst
|
||||||
img_name="${src%%:*}"
|
img_name="${src%%:*}"
|
||||||
@@ -253,56 +233,35 @@ jobs:
|
|||||||
echo "Tagging ${src} -> ${dst}"
|
echo "Tagging ${src} -> ${dst}"
|
||||||
docker tag "$src" "$dst"
|
docker tag "$src" "$dst"
|
||||||
docker push "$dst"
|
docker push "$dst"
|
||||||
image_refs+=("$dst")
|
|
||||||
done
|
done
|
||||||
}
|
}
|
||||||
|
|
||||||
cat > /tmp/goreleaser-artifacts.json <<'JSON'
|
export -f tag_and_push resolve_tags
|
||||||
${{ steps.goreleaser.outputs.artifacts }}
|
|
||||||
JSON
|
|
||||||
|
|
||||||
mapfile -t src_images < <(
|
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
|
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||||
)
|
grep '^ghcr.io/' | while read -r SRC; do
|
||||||
|
tag_and_push "$SRC"
|
||||||
for src in "${src_images[@]}"; do
|
done
|
||||||
tag_and_push "$src"
|
|
||||||
done
|
|
||||||
|
|
||||||
{
|
|
||||||
echo "images_markdown<<EOF"
|
|
||||||
if [[ ${#image_refs[@]} -eq 0 ]]; then
|
|
||||||
echo "_No GHCR images were pushed._"
|
|
||||||
else
|
|
||||||
printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
|
|
||||||
printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
|
|
||||||
done
|
|
||||||
fi
|
|
||||||
echo "EOF"
|
|
||||||
} >> "$GITHUB_OUTPUT"
|
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload linux packages
|
- name: upload linux packages
|
||||||
id: upload_linux_packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-packages
|
name: linux-packages
|
||||||
path: dist/netbird_linux**
|
path: dist/netbird_linux**
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload windows packages
|
- name: upload windows packages
|
||||||
id: upload_windows_packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-packages
|
name: windows-packages
|
||||||
path: dist/netbird_windows**
|
path: dist/netbird_windows**
|
||||||
retention-days: 7
|
retention-days: 7
|
||||||
- name: upload macos packages
|
- name: upload macos packages
|
||||||
id: upload_macos_packages
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: macos-packages
|
name: macos-packages
|
||||||
@@ -311,8 +270,6 @@ jobs:
|
|||||||
|
|
||||||
release_ui:
|
release_ui:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
outputs:
|
|
||||||
release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
|
|
||||||
steps:
|
steps:
|
||||||
- name: Parse semver string
|
- name: Parse semver string
|
||||||
id: semver_parser
|
id: semver_parser
|
||||||
@@ -403,7 +360,6 @@ jobs:
|
|||||||
if: always()
|
if: always()
|
||||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release_ui
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui
|
name: release-ui
|
||||||
@@ -412,8 +368,6 @@ jobs:
|
|||||||
|
|
||||||
release_ui_darwin:
|
release_ui_darwin:
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
outputs:
|
|
||||||
release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
|
|
||||||
steps:
|
steps:
|
||||||
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
- if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
|
||||||
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
run: echo "flags=--snapshot" >> $GITHUB_ENV
|
||||||
@@ -448,258 +402,15 @@ jobs:
|
|||||||
env:
|
env:
|
||||||
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
id: upload_release_ui_darwin
|
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: release-ui-darwin
|
name: release-ui-darwin
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 3
|
retention-days: 3
|
||||||
|
|
||||||
test_windows_installer:
|
|
||||||
name: "Windows Installer / Build Test"
|
|
||||||
runs-on: windows-2022
|
|
||||||
needs: [release, release_ui]
|
|
||||||
strategy:
|
|
||||||
fail-fast: false
|
|
||||||
matrix:
|
|
||||||
include:
|
|
||||||
- arch: amd64
|
|
||||||
wintun_arch: amd64
|
|
||||||
- arch: arm64
|
|
||||||
wintun_arch: arm64
|
|
||||||
defaults:
|
|
||||||
run:
|
|
||||||
shell: powershell
|
|
||||||
env:
|
|
||||||
PackageWorkdir: netbird_windows_${{ matrix.arch }}
|
|
||||||
downloadPath: '${{ github.workspace }}\temp'
|
|
||||||
steps:
|
|
||||||
- name: Parse semver string
|
|
||||||
id: semver_parser
|
|
||||||
uses: booxmedialtd/ws-action-parse-semver@v1
|
|
||||||
with:
|
|
||||||
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
|
|
||||||
version_extractor_regex: '\/v(.*)$'
|
|
||||||
|
|
||||||
- name: Checkout
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Add 7-Zip to PATH
|
|
||||||
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
|
|
||||||
|
|
||||||
- name: Download release artifacts
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
name: release
|
|
||||||
path: release
|
|
||||||
|
|
||||||
- name: Download UI release artifacts
|
|
||||||
uses: actions/download-artifact@v4
|
|
||||||
with:
|
|
||||||
name: release-ui
|
|
||||||
path: release-ui
|
|
||||||
|
|
||||||
- name: Stage binaries into dist
|
|
||||||
run: |
|
|
||||||
$workdir = "dist\${{ env.PackageWorkdir }}"
|
|
||||||
New-Item -ItemType Directory -Force -Path $workdir | Out-Null
|
|
||||||
$client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
|
|
||||||
$ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
|
|
||||||
if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 }
|
|
||||||
if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 }
|
|
||||||
Write-Host "Client: $($client.FullName)"
|
|
||||||
Write-Host "UI: $($ui.FullName)"
|
|
||||||
tar -zvxf $client.FullName -C $workdir
|
|
||||||
tar -zvxf $ui.FullName -C $workdir
|
|
||||||
Get-ChildItem $workdir
|
|
||||||
|
|
||||||
- name: Download wintun
|
|
||||||
uses: carlosperate/download-file-action@v2
|
|
||||||
id: download-wintun
|
|
||||||
with:
|
|
||||||
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
|
|
||||||
file-name: wintun.zip
|
|
||||||
location: ${{ env.downloadPath }}
|
|
||||||
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
|
|
||||||
|
|
||||||
- name: Decompress wintun files
|
|
||||||
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
|
|
||||||
|
|
||||||
- name: Move wintun.dll into dist
|
|
||||||
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
|
||||||
|
|
||||||
- name: Download Mesa3D (amd64 only)
|
|
||||||
uses: carlosperate/download-file-action@v2
|
|
||||||
id: download-mesa3d
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
with:
|
|
||||||
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
|
|
||||||
file-name: mesa3d.7z
|
|
||||||
location: ${{ env.downloadPath }}
|
|
||||||
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
|
|
||||||
|
|
||||||
- name: Extract Mesa3D driver (amd64 only)
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
|
|
||||||
|
|
||||||
- name: Move opengl32.dll into dist (amd64 only)
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
|
|
||||||
|
|
||||||
- name: Download EnVar plugin for NSIS
|
|
||||||
uses: carlosperate/download-file-action@v2
|
|
||||||
with:
|
|
||||||
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
|
|
||||||
file-name: envar_plugin.zip
|
|
||||||
location: ${{ github.workspace }}
|
|
||||||
|
|
||||||
- name: Extract EnVar plugin
|
|
||||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
|
|
||||||
|
|
||||||
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
|
|
||||||
uses: carlosperate/download-file-action@v2
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
with:
|
|
||||||
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
|
|
||||||
file-name: ShellExecAsUser_amd64-Unicode.7z
|
|
||||||
location: ${{ github.workspace }}
|
|
||||||
|
|
||||||
- name: Extract ShellExecAsUser plugin (amd64 only)
|
|
||||||
if: matrix.arch == 'amd64'
|
|
||||||
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
|
|
||||||
|
|
||||||
- name: Build NSIS installer
|
|
||||||
uses: joncloud/makensis-action@v3.3
|
|
||||||
with:
|
|
||||||
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
|
|
||||||
script-file: client/installer.nsis
|
|
||||||
arguments: "/V4 /DARCH=${{ matrix.arch }}"
|
|
||||||
env:
|
|
||||||
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
|
|
||||||
|
|
||||||
- name: Rename NSIS installer
|
|
||||||
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
|
|
||||||
|
|
||||||
- name: Install WiX
|
|
||||||
run: |
|
|
||||||
dotnet tool install --global wix --version 6.0.2
|
|
||||||
wix extension add WixToolset.Util.wixext/6.0.2
|
|
||||||
|
|
||||||
- name: Build MSI installer
|
|
||||||
env:
|
|
||||||
NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}"
|
|
||||||
run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }}
|
|
||||||
|
|
||||||
- name: Upload installer artifacts
|
|
||||||
if: always()
|
|
||||||
uses: actions/upload-artifact@v4
|
|
||||||
with:
|
|
||||||
name: windows-installer-test-${{ matrix.arch }}
|
|
||||||
path: |
|
|
||||||
netbird_installer_test_windows_${{ matrix.arch }}.exe
|
|
||||||
netbird_installer_test_windows_${{ matrix.arch }}.msi
|
|
||||||
retention-days: 3
|
|
||||||
|
|
||||||
comment_release_artifacts:
|
|
||||||
name: Comment release artifacts
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
needs: [release, release_ui, release_ui_darwin]
|
|
||||||
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
|
|
||||||
permissions:
|
|
||||||
contents: read
|
|
||||||
issues: write
|
|
||||||
pull-requests: write
|
|
||||||
steps:
|
|
||||||
- name: Create or update PR comment
|
|
||||||
uses: actions/github-script@v7
|
|
||||||
env:
|
|
||||||
RELEASE_RESULT: ${{ needs.release.result }}
|
|
||||||
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
|
|
||||||
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
|
|
||||||
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
|
|
||||||
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
|
|
||||||
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
|
|
||||||
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
|
|
||||||
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
|
|
||||||
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
|
|
||||||
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
|
|
||||||
with:
|
|
||||||
github-token: ${{ secrets.GITHUB_TOKEN }}
|
|
||||||
script: |
|
|
||||||
const marker = '<!-- netbird-release-artifacts -->';
|
|
||||||
const { owner, repo } = context.repo;
|
|
||||||
const issue_number = context.payload.pull_request.number;
|
|
||||||
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
|
|
||||||
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
|
|
||||||
|
|
||||||
const artifactCell = (url, result) => {
|
|
||||||
if (url) return `[Download](${url})`;
|
|
||||||
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
|
|
||||||
};
|
|
||||||
|
|
||||||
const artifacts = [
|
|
||||||
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
|
|
||||||
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
|
|
||||||
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
|
|
||||||
];
|
|
||||||
|
|
||||||
const artifactRows = artifacts
|
|
||||||
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
|
|
||||||
.join('\n');
|
|
||||||
|
|
||||||
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
|
|
||||||
|
|
||||||
const body = [
|
|
||||||
marker,
|
|
||||||
'## Release artifacts',
|
|
||||||
'',
|
|
||||||
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
|
|
||||||
'',
|
|
||||||
'| Artifact | Link |',
|
|
||||||
'| --- | --- |',
|
|
||||||
artifactRows,
|
|
||||||
'',
|
|
||||||
'### GHCR images (amd64)',
|
|
||||||
ghcrImages,
|
|
||||||
'',
|
|
||||||
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
|
|
||||||
].join('\n');
|
|
||||||
|
|
||||||
const comments = await github.paginate(github.rest.issues.listComments, {
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
issue_number,
|
|
||||||
per_page: 100,
|
|
||||||
});
|
|
||||||
|
|
||||||
const previous = comments.find(comment =>
|
|
||||||
comment.user?.type === 'Bot' && comment.body?.includes(marker)
|
|
||||||
);
|
|
||||||
|
|
||||||
if (previous) {
|
|
||||||
await github.rest.issues.updateComment({
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
comment_id: previous.id,
|
|
||||||
body,
|
|
||||||
});
|
|
||||||
core.info(`Updated release artifacts comment ${previous.id}`);
|
|
||||||
} else {
|
|
||||||
const { data } = await github.rest.issues.createComment({
|
|
||||||
owner,
|
|
||||||
repo,
|
|
||||||
issue_number,
|
|
||||||
body,
|
|
||||||
});
|
|
||||||
core.info(`Created release artifacts comment ${data.id}`);
|
|
||||||
}
|
|
||||||
|
|
||||||
trigger_signer:
|
trigger_signer:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
needs: [release, release_ui, release_ui_darwin, test_windows_installer]
|
needs: [release, release_ui, release_ui_darwin]
|
||||||
if: startsWith(github.ref, 'refs/tags/')
|
if: startsWith(github.ref, 'refs/tags/')
|
||||||
steps:
|
steps:
|
||||||
- name: Trigger binaries sign pipelines
|
- name: Trigger binaries sign pipelines
|
||||||
|
|||||||
28
.github/workflows/sync-tag.yml
vendored
28
.github/workflows/sync-tag.yml
vendored
@@ -9,8 +9,6 @@ concurrency:
|
|||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
|
|
||||||
# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
|
|
||||||
jobs:
|
jobs:
|
||||||
trigger_sync_tag:
|
trigger_sync_tag:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
@@ -22,30 +20,4 @@ jobs:
|
|||||||
ref: main
|
ref: main
|
||||||
repo: ${{ secrets.UPSTREAM_REPO }}
|
repo: ${{ secrets.UPSTREAM_REPO }}
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
|
||||||
|
|
||||||
trigger_android_bump:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
|
||||||
steps:
|
|
||||||
- name: Trigger android-client submodule bump
|
|
||||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
|
||||||
with:
|
|
||||||
workflow: bump-netbird.yml
|
|
||||||
ref: main
|
|
||||||
repo: netbirdio/android-client
|
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
|
||||||
|
|
||||||
trigger_ios_bump:
|
|
||||||
runs-on: ubuntu-latest
|
|
||||||
if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
|
|
||||||
steps:
|
|
||||||
- name: Trigger ios-client submodule bump
|
|
||||||
uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
|
|
||||||
with:
|
|
||||||
workflow: bump-netbird.yml
|
|
||||||
ref: main
|
|
||||||
repo: netbirdio/ios-client
|
|
||||||
token: ${{ secrets.NC_GITHUB_TOKEN }}
|
|
||||||
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
inputs: '{ "tag": "${{ github.ref_name }}" }'
|
||||||
@@ -58,11 +58,6 @@ linters:
|
|||||||
govet:
|
govet:
|
||||||
enable:
|
enable:
|
||||||
- nilness
|
- nilness
|
||||||
disable:
|
|
||||||
# The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
|
|
||||||
# directives but cannot perform the rewrite due to generic type
|
|
||||||
# parameter inference limitations in the Go inliner.
|
|
||||||
- inline
|
|
||||||
enable-all: false
|
enable-all: false
|
||||||
revive:
|
revive:
|
||||||
rules:
|
rules:
|
||||||
|
|||||||
@@ -17,7 +17,6 @@ ENV \
|
|||||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||||
NB_ENABLE_CAPTURE="false" \
|
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|||||||
@@ -23,7 +23,6 @@ ENV \
|
|||||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||||
NB_DISABLE_DNS="true" \
|
NB_DISABLE_DNS="true" \
|
||||||
NB_ENABLE_CAPTURE="false" \
|
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|||||||
@@ -1,196 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"os/signal"
|
|
||||||
"path/filepath"
|
|
||||||
"strings"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
"github.com/netbirdio/netbird/util/capture"
|
|
||||||
)
|
|
||||||
|
|
||||||
var captureCmd = &cobra.Command{
|
|
||||||
Use: "capture",
|
|
||||||
Short: "Capture packets on the WireGuard interface",
|
|
||||||
Long: `Captures decrypted packets flowing through the WireGuard interface.
|
|
||||||
|
|
||||||
Default output is human-readable text. Use --pcap or --output for pcap binary.
|
|
||||||
Requires --enable-capture to be set at service install or reconfigure time.
|
|
||||||
|
|
||||||
Examples:
|
|
||||||
netbird debug capture
|
|
||||||
netbird debug capture host 100.64.0.1 and port 443
|
|
||||||
netbird debug capture tcp
|
|
||||||
netbird debug capture icmp
|
|
||||||
netbird debug capture src host 10.0.0.1 and dst port 80
|
|
||||||
netbird debug capture -o capture.pcap
|
|
||||||
netbird debug capture --pcap | tshark -r -
|
|
||||||
netbird debug capture --pcap | tcpdump -r - -n`,
|
|
||||||
Args: cobra.ArbitraryArgs,
|
|
||||||
RunE: runCapture,
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
debugCmd.AddCommand(captureCmd)
|
|
||||||
|
|
||||||
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
|
|
||||||
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
|
|
||||||
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
|
|
||||||
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
|
|
||||||
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
|
|
||||||
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
|
|
||||||
}
|
|
||||||
|
|
||||||
func runCapture(cmd *cobra.Command, args []string) error {
|
|
||||||
conn, err := getClient(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
if err := conn.Close(); err != nil {
|
|
||||||
cmd.PrintErrf(errCloseConnection, err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
|
|
||||||
req, err := buildCaptureRequest(cmd, args)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
stream, err := client.StartCapture(ctx, req)
|
|
||||||
if err != nil {
|
|
||||||
return handleCaptureError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// First Recv is the empty acceptance message from the server. If the
|
|
||||||
// device is unavailable (kernel WG, not connected, capture disabled),
|
|
||||||
// the server returns an error instead.
|
|
||||||
if _, err := stream.Recv(); err != nil {
|
|
||||||
return handleCaptureError(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
out, cleanup, err := captureOutput(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if req.TextOutput {
|
|
||||||
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
|
|
||||||
} else {
|
|
||||||
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
|
|
||||||
}
|
|
||||||
|
|
||||||
streamErr := streamCapture(ctx, cmd, stream, out)
|
|
||||||
cleanupErr := cleanup()
|
|
||||||
if streamErr != nil {
|
|
||||||
return streamErr
|
|
||||||
}
|
|
||||||
return cleanupErr
|
|
||||||
}
|
|
||||||
|
|
||||||
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
|
|
||||||
req := &proto.StartCaptureRequest{}
|
|
||||||
|
|
||||||
if len(args) > 0 {
|
|
||||||
expr := strings.Join(args, " ")
|
|
||||||
if _, err := capture.ParseFilter(expr); err != nil {
|
|
||||||
return nil, fmt.Errorf("invalid filter: %w", err)
|
|
||||||
}
|
|
||||||
req.FilterExpr = expr
|
|
||||||
}
|
|
||||||
|
|
||||||
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
|
|
||||||
req.SnapLen = snap
|
|
||||||
}
|
|
||||||
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
|
|
||||||
if d < 0 {
|
|
||||||
return nil, fmt.Errorf("duration must not be negative")
|
|
||||||
}
|
|
||||||
req.Duration = durationpb.New(d)
|
|
||||||
}
|
|
||||||
req.Verbose, _ = cmd.Flags().GetBool("verbose")
|
|
||||||
req.Ascii, _ = cmd.Flags().GetBool("ascii")
|
|
||||||
|
|
||||||
outPath, _ := cmd.Flags().GetString("output")
|
|
||||||
forcePcap, _ := cmd.Flags().GetBool("pcap")
|
|
||||||
req.TextOutput = !forcePcap && outPath == ""
|
|
||||||
|
|
||||||
return req, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
|
|
||||||
for {
|
|
||||||
pkt, err := stream.Recv()
|
|
||||||
if err != nil {
|
|
||||||
if ctx.Err() != nil {
|
|
||||||
cmd.PrintErrf("\nCapture stopped.\n")
|
|
||||||
return nil //nolint:nilerr // user interrupted
|
|
||||||
}
|
|
||||||
if err == io.EOF {
|
|
||||||
cmd.PrintErrf("\nCapture finished.\n")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return handleCaptureError(err)
|
|
||||||
}
|
|
||||||
if _, err := out.Write(pkt.GetData()); err != nil {
|
|
||||||
return fmt.Errorf("write output: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// captureOutput returns the writer for capture data and a cleanup function
|
|
||||||
// that finalizes the file. Errors from the cleanup must be propagated.
|
|
||||||
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
|
|
||||||
outPath, _ := cmd.Flags().GetString("output")
|
|
||||||
if outPath == "" {
|
|
||||||
return os.Stdout, func() error { return nil }, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
|
|
||||||
if err != nil {
|
|
||||||
return nil, nil, fmt.Errorf("create output file: %w", err)
|
|
||||||
}
|
|
||||||
tmpPath := f.Name()
|
|
||||||
return f, func() error {
|
|
||||||
var merr *multierror.Error
|
|
||||||
if err := f.Close(); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
|
|
||||||
}
|
|
||||||
fi, statErr := os.Stat(tmpPath)
|
|
||||||
if statErr != nil || fi.Size() == 0 {
|
|
||||||
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
|
|
||||||
}
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
if err := os.Rename(tmpPath, outPath); err != nil {
|
|
||||||
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}
|
|
||||||
cmd.PrintErrf("Wrote %s\n", outPath)
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func handleCaptureError(err error) error {
|
|
||||||
if s, ok := status.FromError(err); ok {
|
|
||||||
return fmt.Errorf("%s", s.Message())
|
|
||||||
}
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
@@ -240,50 +239,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
captureStarted := false
|
|
||||||
if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
|
|
||||||
captureTimeout := duration + 30*time.Second
|
|
||||||
const maxBundleCapture = 10 * time.Minute
|
|
||||||
if captureTimeout > maxBundleCapture {
|
|
||||||
captureTimeout = maxBundleCapture
|
|
||||||
}
|
|
||||||
_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
|
|
||||||
Timeout: durationpb.New(captureTimeout),
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
|
|
||||||
} else {
|
|
||||||
captureStarted = true
|
|
||||||
cmd.Println("Packet capture started.")
|
|
||||||
// Safety: always stop on exit, even if the normal stop below runs too.
|
|
||||||
defer func() {
|
|
||||||
if captureStarted {
|
|
||||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
|
||||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||||
return waitErr
|
return waitErr
|
||||||
}
|
}
|
||||||
cmd.Println("\nDuration completed")
|
cmd.Println("\nDuration completed")
|
||||||
|
|
||||||
if captureStarted {
|
|
||||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
|
||||||
cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
|
|
||||||
} else {
|
|
||||||
captureStarted = false
|
|
||||||
cmd.Println("Packet capture stopped.")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if cpuProfilingStarted {
|
if cpuProfilingStarted {
|
||||||
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
|
||||||
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
|
||||||
@@ -456,5 +416,4 @@ func init() {
|
|||||||
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
|
||||||
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
|
||||||
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||||
forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"golang.org/x/term"
|
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
@@ -24,7 +23,6 @@ import (
|
|||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
|
||||||
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||||
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
|
||||||
}
|
}
|
||||||
@@ -258,7 +256,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -326,7 +324,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -336,7 +334,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
@@ -350,12 +348,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
verificationURIComplete + " " + codeMsg)
|
verificationURIComplete + " " + codeMsg)
|
||||||
}
|
}
|
||||||
|
|
||||||
if showQR {
|
|
||||||
if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
|
|
||||||
printQRCode(f, verificationURIComplete)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
if !noBrowser {
|
if !noBrowser {
|
||||||
|
|||||||
@@ -1,25 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/mdp/qrterminal/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
// printQRCode prints a QR code for the given URL to the writer.
|
|
||||||
// Called only when the user explicitly requests QR output via --qr.
|
|
||||||
func printQRCode(w io.Writer, url string) {
|
|
||||||
if url == "" {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
qrterminal.GenerateWithConfig(url, qrterminal.Config{
|
|
||||||
Level: qrterminal.M,
|
|
||||||
Writer: w,
|
|
||||||
HalfBlocks: true,
|
|
||||||
BlackChar: qrterminal.BLACK_BLACK,
|
|
||||||
WhiteChar: qrterminal.WHITE_WHITE,
|
|
||||||
BlackWhiteChar: qrterminal.BLACK_WHITE,
|
|
||||||
WhiteBlackChar: qrterminal.WHITE_BLACK,
|
|
||||||
QuietZone: qrterminal.QUIET_ZONE,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
@@ -1,26 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"bytes"
|
|
||||||
"testing"
|
|
||||||
)
|
|
||||||
|
|
||||||
func TestPrintQRCode_EmptyURL(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
printQRCode(&buf, "")
|
|
||||||
|
|
||||||
if buf.Len() != 0 {
|
|
||||||
t.Error("expected no output for empty URL")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestPrintQRCode_WritesOutput(t *testing.T) {
|
|
||||||
var buf bytes.Buffer
|
|
||||||
|
|
||||||
printQRCode(&buf, "https://example.com/auth")
|
|
||||||
|
|
||||||
if buf.Len() == 0 {
|
|
||||||
t.Error("expected QR code output for non-empty URL")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -75,7 +75,6 @@ var (
|
|||||||
mtu uint16
|
mtu uint16
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
captureEnabled bool
|
|
||||||
networksDisabled bool
|
networksDisabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
|
|||||||
@@ -44,7 +44,6 @@ func init() {
|
|||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
|
|
||||||
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
|
||||||
|
|
||||||
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
|
serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
|
||||||
if err := serverInstance.Start(); err != nil {
|
if err := serverInstance.Start(); err != nil {
|
||||||
log.Fatalf("failed to start daemon: %v", err)
|
log.Fatalf("failed to start daemon: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -59,10 +59,6 @@ func buildServiceArguments() []string {
|
|||||||
args = append(args, "--disable-update-settings")
|
args = append(args, "--disable-update-settings")
|
||||||
}
|
}
|
||||||
|
|
||||||
if captureEnabled {
|
|
||||||
args = append(args, "--enable-capture")
|
|
||||||
}
|
|
||||||
|
|
||||||
if networksDisabled {
|
if networksDisabled {
|
||||||
args = append(args, "--disable-networks")
|
args = append(args, "--disable-networks")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ type serviceParams struct {
|
|||||||
LogFiles []string `json:"log_files,omitempty"`
|
LogFiles []string `json:"log_files,omitempty"`
|
||||||
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||||
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||||
EnableCapture bool `json:"enable_capture,omitempty"`
|
|
||||||
DisableNetworks bool `json:"disable_networks,omitempty"`
|
DisableNetworks bool `json:"disable_networks,omitempty"`
|
||||||
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||||
}
|
}
|
||||||
@@ -80,7 +79,6 @@ func currentServiceParams() *serviceParams {
|
|||||||
LogFiles: logFiles,
|
LogFiles: logFiles,
|
||||||
DisableProfiles: profilesDisabled,
|
DisableProfiles: profilesDisabled,
|
||||||
DisableUpdateSettings: updateSettingsDisabled,
|
DisableUpdateSettings: updateSettingsDisabled,
|
||||||
EnableCapture: captureEnabled,
|
|
||||||
DisableNetworks: networksDisabled,
|
DisableNetworks: networksDisabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -146,10 +144,6 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
|||||||
updateSettingsDisabled = params.DisableUpdateSettings
|
updateSettingsDisabled = params.DisableUpdateSettings
|
||||||
}
|
}
|
||||||
|
|
||||||
if !serviceCmd.PersistentFlags().Changed("enable-capture") {
|
|
||||||
captureEnabled = params.EnableCapture
|
|
||||||
}
|
|
||||||
|
|
||||||
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
if !serviceCmd.PersistentFlags().Changed("disable-networks") {
|
||||||
networksDisabled = params.DisableNetworks
|
networksDisabled = params.DisableNetworks
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -535,7 +535,6 @@ func fieldToGlobalVar(field string) string {
|
|||||||
"LogFiles": "logFiles",
|
"LogFiles": "logFiles",
|
||||||
"DisableProfiles": "profilesDisabled",
|
"DisableProfiles": "profilesDisabled",
|
||||||
"DisableUpdateSettings": "updateSettingsDisabled",
|
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||||
"EnableCapture": "captureEnabled",
|
|
||||||
"DisableNetworks": "networksDisabled",
|
"DisableNetworks": "networksDisabled",
|
||||||
"ServiceEnvVars": "serviceEnvVars",
|
"ServiceEnvVars": "serviceEnvVars",
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -160,7 +160,7 @@ func startClientDaemon(
|
|||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
|
|
||||||
server := client.New(ctx,
|
server := client.New(ctx,
|
||||||
"", "", false, false, false, false)
|
"", "", false, false, false)
|
||||||
if err := server.Start(); err != nil {
|
if err := server.Start(); err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,9 +39,6 @@ const (
|
|||||||
noBrowserFlag = "no-browser"
|
noBrowserFlag = "no-browser"
|
||||||
noBrowserDesc = "do not open the browser for SSO login"
|
noBrowserDesc = "do not open the browser for SSO login"
|
||||||
|
|
||||||
showQRFlag = "qr"
|
|
||||||
showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"
|
|
||||||
|
|
||||||
profileNameFlag = "profile"
|
profileNameFlag = "profile"
|
||||||
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
|
||||||
)
|
)
|
||||||
@@ -51,7 +48,6 @@ var (
|
|||||||
dnsLabels []string
|
dnsLabels []string
|
||||||
dnsLabelsValidated domain.List
|
dnsLabelsValidated domain.List
|
||||||
noBrowser bool
|
noBrowser bool
|
||||||
showQR bool
|
|
||||||
profileName string
|
profileName string
|
||||||
configPath string
|
configPath string
|
||||||
|
|
||||||
@@ -84,7 +80,6 @@ func init() {
|
|||||||
)
|
)
|
||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
|
|
||||||
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
|
||||||
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")
|
||||||
|
|
||||||
|
|||||||
@@ -1,65 +0,0 @@
|
|||||||
package embed
|
|
||||||
|
|
||||||
import (
|
|
||||||
"io"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/util/capture"
|
|
||||||
)
|
|
||||||
|
|
||||||
// CaptureOptions configures a packet capture session.
|
|
||||||
type CaptureOptions struct {
|
|
||||||
// Output receives pcap-formatted data. Nil disables pcap output.
|
|
||||||
Output io.Writer
|
|
||||||
// TextOutput receives human-readable packet summaries. Nil disables text output.
|
|
||||||
TextOutput io.Writer
|
|
||||||
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
|
|
||||||
// Empty captures all packets.
|
|
||||||
Filter string
|
|
||||||
// Verbose adds seq/ack, TTL, window, and total length to text output.
|
|
||||||
Verbose bool
|
|
||||||
// ASCII dumps transport payload as printable ASCII after each packet line.
|
|
||||||
ASCII bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// CaptureStats reports capture session counters.
|
|
||||||
type CaptureStats struct {
|
|
||||||
Packets int64
|
|
||||||
Bytes int64
|
|
||||||
Dropped int64
|
|
||||||
}
|
|
||||||
|
|
||||||
// CaptureSession represents an active packet capture. Call Stop to end the
|
|
||||||
// capture and flush buffered packets.
|
|
||||||
type CaptureSession struct {
|
|
||||||
sess *capture.Session
|
|
||||||
engine *internal.Engine
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop ends the capture, flushes remaining packets, and detaches from the device.
|
|
||||||
// Safe to call multiple times.
|
|
||||||
func (cs *CaptureSession) Stop() {
|
|
||||||
if cs.engine != nil {
|
|
||||||
_ = cs.engine.SetCapture(nil)
|
|
||||||
cs.engine = nil
|
|
||||||
}
|
|
||||||
if cs.sess != nil {
|
|
||||||
cs.sess.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stats returns current capture counters.
|
|
||||||
func (cs *CaptureSession) Stats() CaptureStats {
|
|
||||||
s := cs.sess.Stats()
|
|
||||||
return CaptureStats{
|
|
||||||
Packets: s.Packets,
|
|
||||||
Bytes: s.Bytes,
|
|
||||||
Dropped: s.Dropped,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Done returns a channel that is closed when the capture's writer goroutine
|
|
||||||
// has fully exited and all buffered packets have been flushed.
|
|
||||||
func (cs *CaptureSession) Done() <-chan struct{} {
|
|
||||||
return cs.sess.Done()
|
|
||||||
}
|
|
||||||
@@ -24,7 +24,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util/capture"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -66,7 +65,7 @@ type Options struct {
|
|||||||
PrivateKey string
|
PrivateKey string
|
||||||
// ManagementURL overrides the default management server URL
|
// ManagementURL overrides the default management server URL
|
||||||
ManagementURL string
|
ManagementURL string
|
||||||
// PreSharedKey is the pre-shared key for the tunnel interface
|
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||||
PreSharedKey string
|
PreSharedKey string
|
||||||
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||||
LogOutput io.Writer
|
LogOutput io.Writer
|
||||||
@@ -82,9 +81,9 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
|
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||||
WireguardPort *int
|
WireguardPort *int
|
||||||
// MTU is the MTU for the tunnel interface.
|
// MTU is the MTU for the WireGuard interface.
|
||||||
// Valid values are in the range 576..8192 bytes.
|
// Valid values are in the range 576..8192 bytes.
|
||||||
// If non-nil, this value overrides any value stored in the config file.
|
// If non-nil, this value overrides any value stored in the config file.
|
||||||
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
||||||
@@ -470,52 +469,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
|
|||||||
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// StartCapture begins capturing packets on this client's tunnel device.
|
|
||||||
// Only one capture can be active at a time; starting a new one stops the previous.
|
|
||||||
// Call StopCapture (or CaptureSession.Stop) to end it.
|
|
||||||
func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
|
|
||||||
engine, err := c.getEngine()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
var matcher capture.Matcher
|
|
||||||
if opts.Filter != "" {
|
|
||||||
m, err := capture.ParseFilter(opts.Filter)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse filter: %w", err)
|
|
||||||
}
|
|
||||||
matcher = m
|
|
||||||
}
|
|
||||||
|
|
||||||
sess, err := capture.NewSession(capture.Options{
|
|
||||||
Output: opts.Output,
|
|
||||||
TextOutput: opts.TextOutput,
|
|
||||||
Matcher: matcher,
|
|
||||||
Verbose: opts.Verbose,
|
|
||||||
ASCII: opts.ASCII,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("create capture session: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := engine.SetCapture(sess); err != nil {
|
|
||||||
sess.Stop()
|
|
||||||
return nil, fmt.Errorf("set capture: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return &CaptureSession{sess: sess, engine: engine}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StopCapture stops the active capture session if one is running.
|
|
||||||
func (c *Client) StopCapture() error {
|
|
||||||
engine, err := c.getEngine()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
return engine.SetCapture(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
// getEngine safely retrieves the engine from the client with proper locking.
|
// getEngine safely retrieves the engine from the client with proper locking.
|
||||||
// Returns ErrClientNotStarted if the client is not started.
|
// Returns ErrClientNotStarted if the client is not started.
|
||||||
// Returns ErrEngineNotStarted if the engine is not available.
|
// Returns ErrEngineNotStarted if the engine is not available.
|
||||||
|
|||||||
@@ -115,13 +115,12 @@ type Manager struct {
|
|||||||
|
|
||||||
localipmanager *localIPManager
|
localipmanager *localIPManager
|
||||||
|
|
||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||||
pendingCapture atomic.Pointer[forwarder.PacketCapture]
|
logger *nblog.Logger
|
||||||
logger *nblog.Logger
|
flowLogger nftypes.FlowLogger
|
||||||
flowLogger nftypes.FlowLogger
|
|
||||||
|
|
||||||
blockRule firewall.Rule
|
blockRule firewall.Rule
|
||||||
|
|
||||||
@@ -352,19 +351,6 @@ func (m *Manager) determineRouting() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
|
|
||||||
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
|
|
||||||
func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
|
|
||||||
if pc == nil {
|
|
||||||
m.pendingCapture.Store(nil)
|
|
||||||
} else {
|
|
||||||
m.pendingCapture.Store(&pc)
|
|
||||||
}
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.SetCapture(pc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// initForwarder initializes the forwarder, it disables routing on errors
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
func (m *Manager) initForwarder() error {
|
func (m *Manager) initForwarder() error {
|
||||||
if m.forwarder.Load() != nil {
|
if m.forwarder.Load() != nil {
|
||||||
@@ -386,11 +372,6 @@ func (m *Manager) initForwarder() error {
|
|||||||
|
|
||||||
m.forwarder.Store(forwarder)
|
m.forwarder.Store(forwarder)
|
||||||
|
|
||||||
// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
|
|
||||||
if pc := m.pendingCapture.Load(); pc != nil {
|
|
||||||
forwarder.SetCapture(*pc)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debug("forwarder initialized")
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -633,7 +614,6 @@ func (m *Manager) resetState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
fwder.SetCapture(nil)
|
|
||||||
fwder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,19 +12,12 @@ import (
|
|||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
)
|
)
|
||||||
|
|
||||||
// PacketCapture captures raw packets for debugging. Implementations must be
|
|
||||||
// safe for concurrent use and must not block.
|
|
||||||
type PacketCapture interface {
|
|
||||||
Offer(data []byte, outbound bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||||
type endpoint struct {
|
type endpoint struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
dispatcher stack.NetworkDispatcher
|
dispatcher stack.NetworkDispatcher
|
||||||
device *wgdevice.Device
|
device *wgdevice.Device
|
||||||
mtu atomic.Uint32
|
mtu atomic.Uint32
|
||||||
capture atomic.Pointer[PacketCapture]
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
@@ -61,17 +54,13 @@ func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error)
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
pktBytes := data.AsSlice()
|
// Send the packet through WireGuard
|
||||||
|
|
||||||
address := netHeader.DestinationAddress()
|
address := netHeader.DestinationAddress()
|
||||||
if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
|
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||||
|
if err != nil {
|
||||||
e.logger.Error1("CreateOutboundPacket: %v", err)
|
e.logger.Error1("CreateOutboundPacket: %v", err)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if pc := e.capture.Load(); pc != nil {
|
|
||||||
(*pc).Offer(pktBytes, true)
|
|
||||||
}
|
|
||||||
written++
|
written++
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -139,16 +139,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return f, nil
|
return f, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCapture sets or clears the packet capture on the forwarder endpoint.
|
|
||||||
// This captures outbound packets that bypass the FilteredDevice (netstack forwarding).
|
|
||||||
func (f *Forwarder) SetCapture(pc PacketCapture) {
|
|
||||||
if pc == nil {
|
|
||||||
f.endpoint.capture.Store(nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
f.endpoint.capture.Store(&pc)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
if len(payload) < header.IPv4MinimumSize {
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||||
|
|||||||
@@ -270,9 +270,5 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []
|
|||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
if pc := f.endpoint.capture.Load(); pc != nil {
|
|
||||||
(*pc).Offer(fullPacket, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
return len(fullPacket)
|
return len(fullPacket)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package device
|
|||||||
import (
|
import (
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
)
|
)
|
||||||
@@ -29,20 +28,11 @@ type PacketFilter interface {
|
|||||||
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
// PacketCapture captures raw packets for debugging. Implementations must be
|
|
||||||
// safe for concurrent use and must not block.
|
|
||||||
type PacketCapture interface {
|
|
||||||
// Offer submits a packet for capture. outbound is true for packets
|
|
||||||
// leaving the host (Read path), false for packets arriving (Write path).
|
|
||||||
Offer(data []byte, outbound bool)
|
|
||||||
}
|
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
type FilteredDevice struct {
|
type FilteredDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
|
|
||||||
filter PacketFilter
|
filter PacketFilter
|
||||||
capture atomic.Pointer[PacketCapture]
|
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
@@ -73,25 +63,20 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||||
return 0, err
|
return 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
d.mutex.RLock()
|
d.mutex.RLock()
|
||||||
filter := d.filter
|
filter := d.filter
|
||||||
d.mutex.RUnlock()
|
d.mutex.RUnlock()
|
||||||
|
|
||||||
if filter != nil {
|
if filter == nil {
|
||||||
for i := 0; i < n; i++ {
|
return
|
||||||
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
|
||||||
n--
|
|
||||||
i--
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if pc := d.capture.Load(); pc != nil {
|
for i := 0; i < n; i++ {
|
||||||
for i := 0; i < n; i++ {
|
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
(*pc).Offer(bufs[i][offset:offset+sizes[i]], true)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
|
n--
|
||||||
|
i--
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -100,13 +85,6 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
|
|
||||||
// Write wraps write method with filtering feature
|
// Write wraps write method with filtering feature
|
||||||
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
||||||
// Capture before filtering so dropped packets are still visible in captures.
|
|
||||||
if pc := d.capture.Load(); pc != nil {
|
|
||||||
for _, buf := range bufs {
|
|
||||||
(*pc).Offer(buf[offset:], false)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
d.mutex.RLock()
|
d.mutex.RLock()
|
||||||
filter := d.filter
|
filter := d.filter
|
||||||
d.mutex.RUnlock()
|
d.mutex.RUnlock()
|
||||||
@@ -118,10 +96,9 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if filter.FilterInbound(buf[offset:], len(buf)) {
|
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||||
dropped++
|
|
||||||
} else {
|
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
|
dropped++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,14 +113,3 @@ func (d *FilteredDevice) SetFilter(filter PacketFilter) {
|
|||||||
d.filter = filter
|
d.filter = filter
|
||||||
d.mutex.Unlock()
|
d.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCapture sets or clears the packet capture sink. Pass nil to disable.
|
|
||||||
// Uses atomic store so the hot path (Read/Write) is a single pointer load
|
|
||||||
// with no locking overhead when capture is off.
|
|
||||||
func (d *FilteredDevice) SetCapture(pc PacketCapture) {
|
|
||||||
if pc == nil {
|
|
||||||
d.capture.Store(nil)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
d.capture.Store(&pc)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
t.Errorf("unexpected error: %v", err)
|
t.Errorf("unexpected error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if n != 1 {
|
if n != 0 {
|
||||||
t.Errorf("expected n=1, got %d", n)
|
t.Errorf("expected n=1, got %d", n)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -201,18 +201,7 @@ Pop $0
|
|||||||
|
|
||||||
Function .onInit
|
Function .onInit
|
||||||
StrCpy $INSTDIR "${INSTALL_DIR}"
|
StrCpy $INSTDIR "${INSTALL_DIR}"
|
||||||
; Default autostart to enabled so silent installs (/S) match the interactive default
|
|
||||||
StrCpy $AutostartEnabled "1"
|
|
||||||
|
|
||||||
; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
|
|
||||||
; in the 32-bit view. Fall back to it so upgrades still find them.
|
|
||||||
SetRegView 64
|
|
||||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
||||||
${If} $R0 == ""
|
|
||||||
SetRegView 32
|
|
||||||
ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
|
|
||||||
SetRegView 64
|
|
||||||
${EndIf}
|
|
||||||
${If} $R0 != ""
|
${If} $R0 != ""
|
||||||
# if silent install jump to uninstall step
|
# if silent install jump to uninstall step
|
||||||
IfSilent uninstall
|
IfSilent uninstall
|
||||||
@@ -225,10 +214,6 @@ ${If} $R0 != ""
|
|||||||
|
|
||||||
${EndIf}
|
${EndIf}
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|
||||||
Function un.onInit
|
|
||||||
SetRegView 64
|
|
||||||
FunctionEnd
|
|
||||||
######################################################################
|
######################################################################
|
||||||
Section -MainProgram
|
Section -MainProgram
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
@@ -243,7 +228,6 @@ Section -MainProgram
|
|||||||
!else
|
!else
|
||||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||||
!endif
|
!endif
|
||||||
File "..\\client\\ui\\assets\\netbird.png"
|
|
||||||
SectionEnd
|
SectionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
@@ -263,11 +247,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
|
|||||||
; Create autostart registry entry based on checkbox
|
; Create autostart registry entry based on checkbox
|
||||||
DetailPrint "Autostart enabled: $AutostartEnabled"
|
DetailPrint "Autostart enabled: $AutostartEnabled"
|
||||||
${If} $AutostartEnabled == "1"
|
${If} $AutostartEnabled == "1"
|
||||||
WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
|
WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
|
||||||
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
|
||||||
${Else}
|
${Else}
|
||||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
|
||||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
DetailPrint "Autostart not enabled by user"
|
DetailPrint "Autostart not enabled by user"
|
||||||
${EndIf}
|
${EndIf}
|
||||||
@@ -301,8 +283,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
|||||||
|
|
||||||
; Remove autostart registry entry
|
; Remove autostart registry entry
|
||||||
DetailPrint "Removing autostart registry entry if exists..."
|
DetailPrint "Removing autostart registry entry if exists..."
|
||||||
DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
|
||||||
; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
|
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
; Handle data deletion based on checkbox
|
; Handle data deletion based on checkbox
|
||||||
@@ -341,7 +321,6 @@ DetailPrint "Removing registry keys..."
|
|||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||||
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
|
|
||||||
|
|
||||||
DetailPrint "Removing application directory from PATH..."
|
DetailPrint "Removing application directory from PATH..."
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
|
|||||||
@@ -333,10 +333,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
c.statusRecorder.MarkSignalConnected()
|
c.statusRecorder.MarkSignalConnected()
|
||||||
|
|
||||||
relayURLs, token := parseRelayInfo(loginResp)
|
relayURLs, token := parseRelayInfo(loginResp)
|
||||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
|
||||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
|
||||||
relayURLs = override
|
|
||||||
}
|
|
||||||
peerConfig := loginResp.GetPeerConfig()
|
peerConfig := loginResp.GetPeerConfig()
|
||||||
|
|
||||||
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)
|
||||||
|
|||||||
@@ -61,7 +61,6 @@ allocs.prof: Allocations profiling information.
|
|||||||
threadcreate.prof: Thread creation profiling information.
|
threadcreate.prof: Thread creation profiling information.
|
||||||
cpu.prof: CPU profiling information.
|
cpu.prof: CPU profiling information.
|
||||||
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
|
||||||
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.
|
|
||||||
|
|
||||||
|
|
||||||
Anonymization Process
|
Anonymization Process
|
||||||
@@ -235,7 +234,6 @@ type BundleGenerator struct {
|
|||||||
logPath string
|
logPath string
|
||||||
tempDir string
|
tempDir string
|
||||||
cpuProfile []byte
|
cpuProfile []byte
|
||||||
capturePath string
|
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
clientMetrics MetricsExporter
|
clientMetrics MetricsExporter
|
||||||
|
|
||||||
@@ -259,8 +257,7 @@ type GeneratorDependencies struct {
|
|||||||
LogPath string
|
LogPath string
|
||||||
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
|
||||||
CPUProfile []byte
|
CPUProfile []byte
|
||||||
CapturePath string
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
RefreshStatus func()
|
|
||||||
ClientMetrics MetricsExporter
|
ClientMetrics MetricsExporter
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -280,7 +277,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
logPath: deps.LogPath,
|
logPath: deps.LogPath,
|
||||||
tempDir: deps.TempDir,
|
tempDir: deps.TempDir,
|
||||||
cpuProfile: deps.CPUProfile,
|
cpuProfile: deps.CPUProfile,
|
||||||
capturePath: deps.CapturePath,
|
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
clientMetrics: deps.ClientMetrics,
|
clientMetrics: deps.ClientMetrics,
|
||||||
|
|
||||||
@@ -350,10 +346,6 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
log.Errorf("failed to add CPU profile to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := g.addCaptureFile(); err != nil {
|
|
||||||
log.Errorf("failed to add capture file to debug bundle: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := g.addStackTrace(); err != nil {
|
if err := g.addStackTrace(); err != nil {
|
||||||
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
log.Errorf("failed to add stack trace to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
@@ -607,12 +599,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
if g.internalConfig.EnableSSHRemotePortForwarding != nil {
|
||||||
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
|
||||||
}
|
}
|
||||||
if g.internalConfig.DisableSSHAuth != nil {
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
|
|
||||||
}
|
|
||||||
if g.internalConfig.SSHJWTCacheTTL != nil {
|
|
||||||
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
|
|
||||||
}
|
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
|
||||||
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
|
||||||
@@ -639,7 +625,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
|
|||||||
}
|
}
|
||||||
|
|
||||||
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
|
||||||
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addProf() (err error) {
|
func (g *BundleGenerator) addProf() (err error) {
|
||||||
@@ -684,29 +669,6 @@ func (g *BundleGenerator) addCPUProfile() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addCaptureFile() error {
|
|
||||||
if g.capturePath == "" {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if g.anonymize {
|
|
||||||
log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.Open(g.capturePath)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("open capture file: %w", err)
|
|
||||||
}
|
|
||||||
defer f.Close()
|
|
||||||
|
|
||||||
if err := g.addFileToZip(f, "capture.pcap"); err != nil {
|
|
||||||
return fmt.Errorf("add capture file to zip: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (g *BundleGenerator) addStackTrace() error {
|
func (g *BundleGenerator) addStackTrace() error {
|
||||||
buf := make([]byte, 5242880) // 5 MB buffer
|
buf := make([]byte, 5242880) // 5 MB buffer
|
||||||
n := runtime.Stack(buf, true)
|
n := runtime.Stack(buf, true)
|
||||||
|
|||||||
@@ -5,21 +5,16 @@ import (
|
|||||||
"bytes"
|
"bytes"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"net"
|
"net"
|
||||||
"net/url"
|
|
||||||
"os"
|
"os"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
"reflect"
|
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/configs"
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -476,8 +471,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
|||||||
anonymize: false,
|
anonymize: false,
|
||||||
input: map[string]any{
|
input: map[string]any{
|
||||||
jsonKeyServiceEnv: map[string]any{
|
jsonKeyServiceEnv: map[string]any{
|
||||||
"HOME": "/root",
|
"HOME": "/root",
|
||||||
"PATH": "/usr/bin",
|
"PATH": "/usr/bin",
|
||||||
"NB_LOG_LEVEL": "debug",
|
"NB_LOG_LEVEL": "debug",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -494,9 +489,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
|
|||||||
anonymize: false,
|
anonymize: false,
|
||||||
input: map[string]any{
|
input: map[string]any{
|
||||||
jsonKeyServiceEnv: map[string]any{
|
jsonKeyServiceEnv: map[string]any{
|
||||||
"NB_SETUP_KEY": "abc123",
|
"NB_SETUP_KEY": "abc123",
|
||||||
"NB_API_TOKEN": "tok_xyz",
|
"NB_API_TOKEN": "tok_xyz",
|
||||||
"NB_LOG_LEVEL": "info",
|
"NB_LOG_LEVEL": "info",
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
check: func(t *testing.T, params map[string]any) {
|
check: func(t *testing.T, params map[string]any) {
|
||||||
@@ -771,127 +766,3 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
|||||||
assert.Contains(t, anonNftables, "chain input {")
|
assert.Contains(t, anonNftables, "chain input {")
|
||||||
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||||
}
|
}
|
||||||
|
|
||||||
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
|
|
||||||
// profilemanager.Config is either rendered in the debug bundle or explicitly
|
|
||||||
// excluded. When a new field is added to Config, this test fails until the
|
|
||||||
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
|
|
||||||
// the excluded set with a justification.
|
|
||||||
func TestAddConfig_AllFieldsCovered(t *testing.T) {
|
|
||||||
excluded := map[string]string{
|
|
||||||
"PrivateKey": "sensitive: WireGuard private key",
|
|
||||||
"PreSharedKey": "sensitive: WireGuard pre-shared key",
|
|
||||||
"SSHKey": "sensitive: SSH private key",
|
|
||||||
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
|
|
||||||
}
|
|
||||||
|
|
||||||
mURL, _ := url.Parse("https://api.example.com:443")
|
|
||||||
aURL, _ := url.Parse("https://admin.example.com:443")
|
|
||||||
bTrue := true
|
|
||||||
iVal := 42
|
|
||||||
cfg := &profilemanager.Config{
|
|
||||||
PrivateKey: "priv",
|
|
||||||
PreSharedKey: "psk",
|
|
||||||
ManagementURL: mURL,
|
|
||||||
AdminURL: aURL,
|
|
||||||
WgIface: "wt0",
|
|
||||||
WgPort: 51820,
|
|
||||||
NetworkMonitor: &bTrue,
|
|
||||||
IFaceBlackList: []string{"eth0"},
|
|
||||||
DisableIPv6Discovery: true,
|
|
||||||
RosenpassEnabled: true,
|
|
||||||
RosenpassPermissive: true,
|
|
||||||
ServerSSHAllowed: &bTrue,
|
|
||||||
EnableSSHRoot: &bTrue,
|
|
||||||
EnableSSHSFTP: &bTrue,
|
|
||||||
EnableSSHLocalPortForwarding: &bTrue,
|
|
||||||
EnableSSHRemotePortForwarding: &bTrue,
|
|
||||||
DisableSSHAuth: &bTrue,
|
|
||||||
SSHJWTCacheTTL: &iVal,
|
|
||||||
DisableClientRoutes: true,
|
|
||||||
DisableServerRoutes: true,
|
|
||||||
DisableDNS: true,
|
|
||||||
DisableFirewall: true,
|
|
||||||
BlockLANAccess: true,
|
|
||||||
BlockInbound: true,
|
|
||||||
DisableNotifications: &bTrue,
|
|
||||||
DNSLabels: domain.List{},
|
|
||||||
SSHKey: "sshkey",
|
|
||||||
NATExternalIPs: []string{"1.2.3.4"},
|
|
||||||
CustomDNSAddress: "1.1.1.1:53",
|
|
||||||
DisableAutoConnect: true,
|
|
||||||
DNSRouteInterval: 5 * time.Second,
|
|
||||||
ClientCertPath: "/tmp/cert",
|
|
||||||
ClientCertKeyPath: "/tmp/key",
|
|
||||||
LazyConnectionEnabled: true,
|
|
||||||
MTU: 1280,
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, anonymize := range []bool{false, true} {
|
|
||||||
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
|
|
||||||
g := &BundleGenerator{
|
|
||||||
anonymizer: newAnonymizerForTest(),
|
|
||||||
internalConfig: cfg,
|
|
||||||
anonymize: anonymize,
|
|
||||||
}
|
|
||||||
|
|
||||||
var sb strings.Builder
|
|
||||||
g.addCommonConfigFields(&sb)
|
|
||||||
rendered := sb.String() + renderAddConfigSpecific(g)
|
|
||||||
|
|
||||||
val := reflect.ValueOf(cfg).Elem()
|
|
||||||
typ := val.Type()
|
|
||||||
var missing []string
|
|
||||||
for i := 0; i < typ.NumField(); i++ {
|
|
||||||
name := typ.Field(i).Name
|
|
||||||
if _, ok := excluded[name]; ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !strings.Contains(rendered, name+":") {
|
|
||||||
missing = append(missing, name)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(missing) > 0 {
|
|
||||||
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
|
|
||||||
"Either render the field in addCommonConfigFields/addConfig, "+
|
|
||||||
"or add it to the excluded map with a justification.", missing)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
|
|
||||||
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
|
|
||||||
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
|
|
||||||
// production shape without needing to write an actual zip.
|
|
||||||
func renderAddConfigSpecific(g *BundleGenerator) string {
|
|
||||||
var sb strings.Builder
|
|
||||||
if g.anonymize {
|
|
||||||
if g.internalConfig.ManagementURL != nil {
|
|
||||||
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
|
|
||||||
}
|
|
||||||
if g.internalConfig.AdminURL != nil {
|
|
||||||
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
|
|
||||||
}
|
|
||||||
sb.WriteString("NATExternalIPs: x\n")
|
|
||||||
if g.internalConfig.CustomDNSAddress != "" {
|
|
||||||
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if g.internalConfig.ManagementURL != nil {
|
|
||||||
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
|
|
||||||
}
|
|
||||||
if g.internalConfig.AdminURL != nil {
|
|
||||||
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
|
|
||||||
}
|
|
||||||
sb.WriteString("NATExternalIPs: x\n")
|
|
||||||
if g.internalConfig.CustomDNSAddress != "" {
|
|
||||||
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return sb.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newAnonymizerForTest() *anonymize.Anonymizer {
|
|
||||||
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
defaultResolvConfPath = "/etc/resolv.conf"
|
defaultResolvConfPath = "/etc/resolv.conf"
|
||||||
nsswitchConfPath = "/etc/nsswitch.conf"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type resolvConf struct {
|
type resolvConf struct {
|
||||||
|
|||||||
@@ -1,10 +1,7 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"math"
|
|
||||||
"net"
|
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
c.dispatch(w, r, math.MaxInt)
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatch routes a DNS request through the chain, skipping handlers with
|
|
||||||
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
|
|
||||||
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
|
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
|
|||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
if entry.Priority > maxPriority {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !c.isHandlerMatch(qname, entry) {
|
if !c.isHandlerMatch(qname, entry) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
|
|||||||
cw.response.Len(), meta, time.Since(startTime))
|
cw.response.Len(), meta, time.Since(startTime))
|
||||||
}
|
}
|
||||||
|
|
||||||
// ResolveInternal runs an in-process DNS query against the chain, skipping any
|
|
||||||
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
|
|
||||||
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
|
|
||||||
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
|
|
||||||
// (bounded by the invoked handler's internal timeout).
|
|
||||||
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
|
||||||
if len(r.Question) == 0 {
|
|
||||||
return nil, fmt.Errorf("empty question")
|
|
||||||
}
|
|
||||||
|
|
||||||
base := &internalResponseWriter{}
|
|
||||||
done := make(chan struct{})
|
|
||||||
go func() {
|
|
||||||
c.dispatch(base, r, maxPriority)
|
|
||||||
close(done)
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
case <-ctx.Done():
|
|
||||||
// Prefer a completed response if dispatch finished concurrently with cancellation.
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
default:
|
|
||||||
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
|
|
||||||
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
|
|
||||||
strings.ToLower(r.Question[0].Name), maxPriority)
|
|
||||||
}
|
|
||||||
return base.response, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
|
|
||||||
// priority ≤ maxPriority.
|
|
||||||
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
for _, h := range c.handlers {
|
|
||||||
if h.Pattern == "." && h.Priority <= maxPriority {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
switch {
|
switch {
|
||||||
case entry.Pattern == ".":
|
case entry.Pattern == ".":
|
||||||
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// internalResponseWriter captures a dns.Msg for in-process chain queries.
|
|
||||||
type internalResponseWriter struct {
|
|
||||||
response *dns.Msg
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
|
|
||||||
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
|
|
||||||
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
|
|
||||||
// still surface their answer to ResolveInternal.
|
|
||||||
func (w *internalResponseWriter) Write(p []byte) (int, error) {
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
if err := msg.Unpack(p); err != nil {
|
|
||||||
return 0, err
|
|
||||||
}
|
|
||||||
w.response = msg
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *internalResponseWriter) Close() error { return nil }
|
|
||||||
func (w *internalResponseWriter) TsigStatus() error { return nil }
|
|
||||||
|
|
||||||
// TsigTimersOnly is part of dns.ResponseWriter.
|
|
||||||
func (w *internalResponseWriter) TsigTimersOnly(bool) {
|
|
||||||
// no-op: in-process queries carry no TSIG state.
|
|
||||||
}
|
|
||||||
|
|
||||||
// Hijack is part of dns.ResponseWriter.
|
|
||||||
func (w *internalResponseWriter) Hijack() {
|
|
||||||
// no-op: in-process queries have no underlying connection to hand off.
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,15 +1,11 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// answeringHandler writes a fixed A record to ack the query. Used to verify
|
|
||||||
// which handler ResolveInternal dispatches to.
|
|
||||||
type answeringHandler struct {
|
|
||||||
name string
|
|
||||||
ip string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
resp.Answer = []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP(h.ip).To4(),
|
|
||||||
}}
|
|
||||||
_ = w.WriteMsg(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *answeringHandler) String() string { return h.name }
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
|
|
||||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
|
||||||
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
|
||||||
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
assert.NotNil(t, resp)
|
|
||||||
assert.Equal(t, 1, len(resp.Answer))
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
assert.True(t, ok)
|
|
||||||
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
|
|
||||||
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.Error(t, err, "no handler at or below maxPriority should error")
|
|
||||||
}
|
|
||||||
|
|
||||||
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
|
|
||||||
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
|
|
||||||
type rawWriteHandler struct {
|
|
||||||
ip string
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
resp.Answer = []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP(h.ip).To4(),
|
|
||||||
}}
|
|
||||||
packed, err := resp.Pack()
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
_, _ = w.Write(packed)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
|
|
||||||
assert.NoError(t, err)
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
require.Len(t, resp.Answer, 1)
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
require.True(t, ok)
|
|
||||||
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
|
|
||||||
assert.Error(t, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
|
|
||||||
type hangingHandler struct {
|
|
||||||
block chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|
||||||
<-h.block
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(r)
|
|
||||||
_ = w.WriteMsg(resp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (h *hangingHandler) String() string { return "hangingHandler" }
|
|
||||||
|
|
||||||
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
h := &hangingHandler{block: make(chan struct{})}
|
|
||||||
defer close(h.block)
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
r := new(dns.Msg)
|
|
||||||
r.SetQuestion("example.com.", dns.TypeA)
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
start := time.Now()
|
|
||||||
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
|
|
||||||
elapsed := time.Since(start)
|
|
||||||
|
|
||||||
assert.Error(t, err)
|
|
||||||
assert.ErrorIs(t, err, context.DeadlineExceeded)
|
|
||||||
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
|
|
||||||
chain := nbdns.NewHandlerChain()
|
|
||||||
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
|
|
||||||
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
|
|
||||||
|
|
||||||
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
|
|
||||||
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
|
|
||||||
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityDefault)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
|
|
||||||
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityDefault)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
|
||||||
|
|
||||||
// Primary nsgroup case: root handler lands at PriorityUpstream.
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityUpstream)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityUpstream)
|
|
||||||
|
|
||||||
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
|
|
||||||
chain.AddHandler(".", h, nbdns.PriorityFallback)
|
|
||||||
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
|
|
||||||
chain.RemoveHandler(".", nbdns.PriorityFallback)
|
|
||||||
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,12 +46,12 @@ type restoreHostManager interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newHostManager(wgInterface string) (hostManager, error) {
|
func newHostManager(wgInterface string) (hostManager, error) {
|
||||||
osManager, reason, err := getOSDNSManagerType()
|
osManager, err := getOSDNSManagerType()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
return nil, fmt.Errorf("get os dns manager type: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason)
|
log.Infof("System DNS manager discovered: %s", osManager)
|
||||||
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
mgr, err := newHostManagerFromType(wgInterface, osManager)
|
||||||
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func getOSDNSManagerType() (osManagerType, string, error) {
|
func getOSDNSManagerType() (osManagerType, error) {
|
||||||
resolved := isSystemdResolvedRunning()
|
|
||||||
nss := isLibnssResolveUsed()
|
|
||||||
stub := checkStub()
|
|
||||||
|
|
||||||
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
|
|
||||||
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
|
|
||||||
// that go through nss-resolve, and in foreign mode they can loop back
|
|
||||||
// through resolved as an upstream.
|
|
||||||
if resolved && (nss || stub) {
|
|
||||||
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
mgr, reason, rejected, err := scanResolvConfHeader()
|
|
||||||
if err != nil {
|
|
||||||
return 0, "", err
|
|
||||||
}
|
|
||||||
if reason != "" {
|
|
||||||
return mgr, reason, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
|
|
||||||
if len(rejected) > 0 {
|
|
||||||
fallback += "; rejected: " + strings.Join(rejected, ", ")
|
|
||||||
}
|
|
||||||
return fileManager, fallback, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
|
|
||||||
// matching manager. If reason is empty the caller should pick file mode and
|
|
||||||
// use rejected for diagnostics.
|
|
||||||
func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
|
||||||
file, err := os.Open(defaultResolvConfPath)
|
file, err := os.Open(defaultResolvConfPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if cerr := file.Close(); cerr != nil {
|
if err := file.Close(); err != nil {
|
||||||
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr)
|
log.Errorf("close file %s: %s", defaultResolvConfPath, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rejected []string
|
|
||||||
scanner := bufio.NewScanner(file)
|
scanner := bufio.NewScanner(file)
|
||||||
for scanner.Scan() {
|
for scanner.Scan() {
|
||||||
text := scanner.Text()
|
text := scanner.Text()
|
||||||
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if text[0] != '#' {
|
if text[0] != '#' {
|
||||||
break
|
return fileManager, nil
|
||||||
}
|
}
|
||||||
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" {
|
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
||||||
return mgr, reason, nil, nil
|
return netbirdManager, nil
|
||||||
} else if rej != "" {
|
}
|
||||||
rejected = append(rejected, rej)
|
if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
||||||
|
return networkManager, nil
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
|
||||||
|
if checkStub() {
|
||||||
|
return systemdManager, nil
|
||||||
|
} else {
|
||||||
|
return fileManager, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if strings.Contains(text, "resolvconf") {
|
||||||
|
if isSystemdResolveConfMode() {
|
||||||
|
return systemdManager, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return resolvConfManager, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if err := scanner.Err(); err != nil && err != io.EOF {
|
if err := scanner.Err(); err != nil && err != io.EOF {
|
||||||
return 0, "", nil, fmt.Errorf("scan: %w", err)
|
return 0, fmt.Errorf("scan: %w", err)
|
||||||
}
|
}
|
||||||
return 0, "", rejected, nil
|
|
||||||
|
return fileManager, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// matchResolvConfHeader inspects a single comment line. Returns either a
|
// checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
|
||||||
// definitive (manager, reason) or a non-empty rejected diagnostic.
|
|
||||||
func matchResolvConfHeader(text string) (osManagerType, string, string) {
|
|
||||||
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
|
|
||||||
return netbirdManager, "netbird-managed resolv.conf header detected", ""
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "NetworkManager") {
|
|
||||||
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
|
|
||||||
return networkManager, "NetworkManager header + supported version on dbus", ""
|
|
||||||
}
|
|
||||||
return 0, "", "NetworkManager header (no dbus or unsupported version)"
|
|
||||||
}
|
|
||||||
if strings.Contains(text, "resolvconf") {
|
|
||||||
if isSystemdResolveConfMode() {
|
|
||||||
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
|
|
||||||
}
|
|
||||||
return resolvConfManager, "resolvconf header detected", ""
|
|
||||||
}
|
|
||||||
return 0, "", ""
|
|
||||||
}
|
|
||||||
|
|
||||||
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed
|
|
||||||
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
|
|
||||||
// into file mode while resolved is active.
|
|
||||||
func checkStub() bool {
|
func checkStub() bool {
|
||||||
rConf, err := parseDefaultResolvConf()
|
rConf, err := parseDefaultResolvConf()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err)
|
log.Warnf("failed to parse resolv conf: %s", err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,36 +139,3 @@ func checkStub() bool {
|
|||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
|
|
||||||
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
|
|
||||||
// delegated to systemd-resolved regardless of /etc/resolv.conf.
|
|
||||||
func isLibnssResolveUsed() bool {
|
|
||||||
bs, err := os.ReadFile(nsswitchConfPath)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("read %s: %v", nsswitchConfPath, err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return parseNsswitchResolveAhead(bs)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseNsswitchResolveAhead(data []byte) bool {
|
|
||||||
for _, line := range strings.Split(string(data), "\n") {
|
|
||||||
if i := strings.IndexByte(line, '#'); i >= 0 {
|
|
||||||
line = line[:i]
|
|
||||||
}
|
|
||||||
fields := strings.Fields(line)
|
|
||||||
if len(fields) < 2 || fields[0] != "hosts:" {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
for _, module := range fields[1:] {
|
|
||||||
switch module {
|
|
||||||
case "dns":
|
|
||||||
return false
|
|
||||||
case "resolve":
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,76 +0,0 @@
|
|||||||
//go:build (linux && !android) || freebsd
|
|
||||||
|
|
||||||
package dns
|
|
||||||
|
|
||||||
import "testing"
|
|
||||||
|
|
||||||
func TestParseNsswitchResolveAhead(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
in string
|
|
||||||
want bool
|
|
||||||
}{
|
|
||||||
{
|
|
||||||
name: "resolve before dns with action token",
|
|
||||||
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "dns before resolve",
|
|
||||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "debian default with only dns",
|
|
||||||
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "neither resolve nor dns",
|
|
||||||
in: "hosts: files myhostname\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "no hosts line",
|
|
||||||
in: "passwd: files systemd\ngroup: files systemd\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "empty",
|
|
||||||
in: "",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "comments and blank lines ignored",
|
|
||||||
in: "# comment\n\n# another\nhosts: resolve dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "trailing inline comment",
|
|
||||||
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "hosts token must be the first field",
|
|
||||||
in: " hosts: resolve dns\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "other db line mentioning resolve is ignored",
|
|
||||||
in: "networks: resolve\nhosts: dns\n",
|
|
||||||
want: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "only resolve, no dns",
|
|
||||||
in: "hosts: files resolve\n",
|
|
||||||
want: true,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
for _, tt := range tests {
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
|
||||||
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
|
|
||||||
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -2,83 +2,40 @@ package mgmt
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
"os"
|
|
||||||
"slices"
|
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sync/singleflight"
|
|
||||||
|
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/resutil"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const dnsTimeout = 5 * time.Second
|
||||||
dnsTimeout = 5 * time.Second
|
|
||||||
defaultTTL = 300 * time.Second
|
|
||||||
refreshBackoff = 30 * time.Second
|
|
||||||
|
|
||||||
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
|
// Resolver caches critical NetBird infrastructure domains
|
||||||
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
|
|
||||||
)
|
|
||||||
|
|
||||||
// ChainResolver lets the cache refresh stale entries through the DNS handler
|
|
||||||
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
|
|
||||||
// system resolver.
|
|
||||||
type ChainResolver interface {
|
|
||||||
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
|
|
||||||
HasRootHandlerAtOrBelow(maxPriority int) bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
|
|
||||||
// records and cachedAt are set at construction and treated as immutable;
|
|
||||||
// lastFailedRefresh and consecFailures are mutable and must be accessed under
|
|
||||||
// Resolver.mutex.
|
|
||||||
type cachedRecord struct {
|
|
||||||
records []dns.RR
|
|
||||||
cachedAt time.Time
|
|
||||||
lastFailedRefresh time.Time
|
|
||||||
consecFailures int
|
|
||||||
}
|
|
||||||
|
|
||||||
// Resolver caches critical NetBird infrastructure domains.
|
|
||||||
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
|
|
||||||
type Resolver struct {
|
type Resolver struct {
|
||||||
records map[dns.Question]*cachedRecord
|
records map[dns.Question][]dns.RR
|
||||||
mgmtDomain *domain.Domain
|
mgmtDomain *domain.Domain
|
||||||
serverDomains *dnsconfig.ServerDomains
|
serverDomains *dnsconfig.ServerDomains
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
chain ChainResolver
|
type ipsResponse struct {
|
||||||
chainMaxPriority int
|
ips []netip.Addr
|
||||||
refreshGroup singleflight.Group
|
err error
|
||||||
|
|
||||||
// refreshing tracks questions whose refresh is running via the OS
|
|
||||||
// fallback path. A ServeDNS hit for a question in this map indicates
|
|
||||||
// the OS resolver routed the recursive query back to us (loop). Only
|
|
||||||
// the OS path arms this so chain-path refreshes don't produce false
|
|
||||||
// positives. The atomic bool is CAS-flipped once per refresh to
|
|
||||||
// throttle the warning log.
|
|
||||||
refreshing map[dns.Question]*atomic.Bool
|
|
||||||
|
|
||||||
cacheTTL time.Duration
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewResolver creates a new management domains cache resolver.
|
// NewResolver creates a new management domains cache resolver.
|
||||||
func NewResolver() *Resolver {
|
func NewResolver() *Resolver {
|
||||||
return &Resolver{
|
return &Resolver{
|
||||||
records: make(map[dns.Question]*cachedRecord),
|
records: make(map[dns.Question][]dns.RR),
|
||||||
refreshing: make(map[dns.Question]*atomic.Bool),
|
|
||||||
cacheTTL: resolveCacheTTL(),
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
|
|||||||
return "MgmtCacheResolver"
|
return "MgmtCacheResolver"
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetChainResolver wires the handler chain used to refresh stale cache entries.
|
// ServeDNS implements dns.Handler interface.
|
||||||
// maxPriority caps which handlers may answer refresh queries (typically
|
|
||||||
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
|
|
||||||
// mgmt/route/local handlers are skipped).
|
|
||||||
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.chain = chain
|
|
||||||
m.chainMaxPriority = maxPriority
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// ServeDNS serves cached A/AAAA records. Stale entries are returned
|
|
||||||
// immediately and refreshed asynchronously (stale-while-revalidate).
|
|
||||||
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
m.continueToNext(w, r)
|
m.continueToNext(w, r)
|
||||||
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
cached, found := m.records[question]
|
records, found := m.records[question]
|
||||||
inflight := m.refreshing[question]
|
|
||||||
var shouldRefresh bool
|
|
||||||
if found {
|
|
||||||
stale := time.Since(cached.cachedAt) > m.cacheTTL
|
|
||||||
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
|
|
||||||
shouldRefresh = stale && !inBackoff
|
|
||||||
}
|
|
||||||
m.mutex.RUnlock()
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
if !found {
|
if !found {
|
||||||
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if inflight != nil && inflight.CompareAndSwap(false, true) {
|
|
||||||
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
|
|
||||||
question.Name)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Skip scheduling a refresh goroutine if one is already inflight for
|
|
||||||
// this question; singleflight would dedup anyway but skipping avoids
|
|
||||||
// a parked goroutine per stale hit under bursty load.
|
|
||||||
if shouldRefresh && inflight == nil {
|
|
||||||
m.scheduleRefresh(question, cached)
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := &dns.Msg{}
|
resp := &dns.Msg{}
|
||||||
resp.SetReply(r)
|
resp.SetReply(r)
|
||||||
resp.Authoritative = false
|
resp.Authoritative = false
|
||||||
resp.RecursionAvailable = true
|
resp.RecursionAvailable = true
|
||||||
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
|
|
||||||
|
resp.Answer = append(resp.Answer, records...)
|
||||||
|
|
||||||
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
|
||||||
|
|
||||||
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddDomain resolves a domain and stores its A/AAAA records in the cache.
|
// AddDomain manually adds a domain to cache by resolving it.
|
||||||
// A family that resolves NODATA (nil err, zero records) evicts any stale
|
|
||||||
// entry for that qtype.
|
|
||||||
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
|
||||||
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName)
|
ips, err := lookupIPWithExtraTimeout(ctx, d)
|
||||||
|
if err != nil {
|
||||||
if errA != nil && errAAAA != nil {
|
return err
|
||||||
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(aRecords) == 0 && len(aaaaRecords) == 0 {
|
var aRecords, aaaaRecords []dns.RR
|
||||||
if err := errors.Join(errA, errAAAA); err != nil {
|
for _, ip := range ips {
|
||||||
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err)
|
if ip.Is4() {
|
||||||
|
rr := &dns.A{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: dnsName,
|
||||||
|
Rrtype: dns.TypeA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
A: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aRecords = append(aRecords, rr)
|
||||||
|
} else if ip.Is6() {
|
||||||
|
rr := &dns.AAAA{
|
||||||
|
Hdr: dns.RR_Header{
|
||||||
|
Name: dnsName,
|
||||||
|
Rrtype: dns.TypeAAAA,
|
||||||
|
Class: dns.ClassINET,
|
||||||
|
Ttl: 300,
|
||||||
|
},
|
||||||
|
AAAA: ip.AsSlice(),
|
||||||
|
}
|
||||||
|
aaaaRecords = append(aaaaRecords, rr)
|
||||||
}
|
}
|
||||||
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
now := time.Now()
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now)
|
if len(aRecords) > 0 {
|
||||||
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now)
|
aQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aQuestion] = aRecords
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records",
|
if len(aaaaRecords) > 0 {
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
m.records[aaaaQuestion] = aaaaRecords
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("added domain=%s with %d A records and %d AAAA records",
|
||||||
d.SafeString(), len(aRecords), len(aaaaRecords))
|
d.SafeString(), len(aRecords), len(aaaaRecords))
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache
|
func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
|
||||||
// untouched on error. Caller holds m.mutex.
|
log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) {
|
defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
|
||||||
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET}
|
resultChan := make(chan *ipsResponse, 1)
|
||||||
switch {
|
|
||||||
case len(records) > 0:
|
|
||||||
m.records[q] = &cachedRecord{records: records, cachedAt: now}
|
|
||||||
case err == nil:
|
|
||||||
delete(m.records, q)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per
|
go func() {
|
||||||
// unique in-flight key; bursty stale hits share its channel. expected is the
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
|
||||||
// cachedRecord pointer observed by the caller; the refresh only mutates the
|
resultChan <- &ipsResponse{
|
||||||
// cache if that pointer is still the one stored, so a stale in-flight refresh
|
err: err,
|
||||||
// can't clobber a newer entry written by AddDomain or a competing refresh.
|
ips: ips,
|
||||||
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
|
|
||||||
key := question.Name + "|" + dns.TypeToString[question.Qtype]
|
|
||||||
_ = m.refreshGroup.DoChan(key, func() (any, error) {
|
|
||||||
return nil, m.refreshQuestion(question, expected)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// refreshQuestion replaces the cached records on success, or marks the entry
|
|
||||||
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
|
|
||||||
// a resolver loop by spotting a query for this same question arriving on us.
|
|
||||||
// expected pins the cache entry observed at schedule time; mutations only apply
|
|
||||||
// if m.records[question] still points at it.
|
|
||||||
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
|
||||||
defer cancel()
|
|
||||||
|
|
||||||
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
|
|
||||||
if err != nil {
|
|
||||||
m.markRefreshFailed(question, expected)
|
|
||||||
return fmt.Errorf("parse domain: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
records, err := m.lookupRecords(ctx, d, question)
|
|
||||||
if err != nil {
|
|
||||||
fails := m.markRefreshFailed(question, expected)
|
|
||||||
logf := log.Warnf
|
|
||||||
if fails == 0 || fails > 1 {
|
|
||||||
logf = log.Debugf
|
|
||||||
}
|
}
|
||||||
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)",
|
}()
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
|
|
||||||
return err
|
var resp *ipsResponse
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-time.After(dnsTimeout + time.Millisecond*500):
|
||||||
|
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
|
||||||
|
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
|
||||||
|
case <-ctx.Done():
|
||||||
|
return nil, ctx.Err()
|
||||||
|
case resp = <-resultChan:
|
||||||
}
|
}
|
||||||
|
|
||||||
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale.
|
if resp.err != nil {
|
||||||
if len(records) == 0 {
|
return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
|
||||||
m.mutex.Lock()
|
|
||||||
if m.records[question] == expected {
|
|
||||||
delete(m.records, question)
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
return resp.ips, nil
|
||||||
now := time.Now()
|
|
||||||
m.mutex.Lock()
|
|
||||||
if m.records[question] != expected {
|
|
||||||
m.mutex.Unlock()
|
|
||||||
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
m.records[question] = &cachedRecord{records: records, cachedAt: now}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
log.Infof("refreshed mgmt cache domain=%s type=%s",
|
|
||||||
d.SafeString(), dns.TypeToString[question.Qtype])
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) markRefreshing(question dns.Question) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
m.refreshing[question] = &atomic.Bool{}
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Resolver) clearRefreshing(question dns.Question) {
|
|
||||||
m.mutex.Lock()
|
|
||||||
delete(m.refreshing, question)
|
|
||||||
m.mutex.Unlock()
|
|
||||||
}
|
|
||||||
|
|
||||||
// markRefreshFailed arms the backoff and returns the new consecutive-failure
|
|
||||||
// count so callers can downgrade subsequent failure logs to debug.
|
|
||||||
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
c, ok := m.records[question]
|
|
||||||
if !ok || c != expected {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
c.lastFailedRefresh = time.Now()
|
|
||||||
c.consecFailures++
|
|
||||||
return c.consecFailures
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
|
|
||||||
// callers tell records, NODATA (nil err, no records), and failure apart.
|
|
||||||
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
chain := m.chain
|
|
||||||
maxPriority := m.chainMaxPriority
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
||||||
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
|
|
||||||
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: drop once every supported OS registers a fallback resolver. Safe
|
|
||||||
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
|
|
||||||
// not the system resolver, so net.DefaultResolver will not loop back.
|
|
||||||
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
|
|
||||||
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupRecords resolves a single record type via chain or OS. The OS branch
|
|
||||||
// arms the loop detector for the duration of its call so that ServeDNS can
|
|
||||||
// spot the OS resolver routing the recursive query back to us.
|
|
||||||
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
|
|
||||||
m.mutex.RLock()
|
|
||||||
chain := m.chain
|
|
||||||
maxPriority := m.chainMaxPriority
|
|
||||||
m.mutex.RUnlock()
|
|
||||||
|
|
||||||
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
|
|
||||||
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// TODO: drop once every supported OS registers a fallback resolver.
|
|
||||||
m.markRefreshing(q)
|
|
||||||
defer m.clearRefreshing(q)
|
|
||||||
|
|
||||||
return m.osLookup(ctx, d, q.Name, q.Qtype)
|
|
||||||
}
|
|
||||||
|
|
||||||
// lookupViaChain resolves via the handler chain and rewrites each RR to use
|
|
||||||
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
|
|
||||||
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
|
|
||||||
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
||||||
msg := &dns.Msg{}
|
|
||||||
msg.SetQuestion(dnsName, qtype)
|
|
||||||
msg.RecursionDesired = true
|
|
||||||
|
|
||||||
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("chain resolve: %w", err)
|
|
||||||
}
|
|
||||||
if resp == nil {
|
|
||||||
return nil, fmt.Errorf("chain resolve returned nil response")
|
|
||||||
}
|
|
||||||
if resp.Rcode != dns.RcodeSuccess {
|
|
||||||
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
|
|
||||||
}
|
|
||||||
|
|
||||||
ttl := uint32(m.cacheTTL.Seconds())
|
|
||||||
owners := cnameOwners(dnsName, resp.Answer)
|
|
||||||
var filtered []dns.RR
|
|
||||||
for _, rr := range resp.Answer {
|
|
||||||
h := rr.Header()
|
|
||||||
if h.Class != dns.ClassINET || h.Rrtype != qtype {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
|
|
||||||
filtered = append(filtered, cp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return filtered, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// osLookup resolves a single family via net.DefaultResolver using resutil,
|
|
||||||
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
|
|
||||||
// returns (nil, nil).
|
|
||||||
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
|
|
||||||
network := resutil.NetworkForQtype(qtype)
|
|
||||||
if network == "" {
|
|
||||||
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
||||||
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
|
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
|
|
||||||
if result.Rcode == dns.RcodeSuccess {
|
|
||||||
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
if result.Err != nil {
|
|
||||||
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
|
|
||||||
}
|
|
||||||
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
|
|
||||||
}
|
|
||||||
|
|
||||||
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
|
|
||||||
// so downstream resolvers don't cache an answer for longer than we will.
|
|
||||||
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
|
|
||||||
remaining := m.cacheTTL - time.Since(cachedAt)
|
|
||||||
if remaining <= 0 {
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
return uint32((remaining + time.Second - 1) / time.Second)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// PopulateFromConfig extracts and caches domains from the client configuration.
|
// PopulateFromConfig extracts and caches domains from the client configuration.
|
||||||
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
aQuestion := dns.Question{
|
||||||
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET}
|
Name: dnsName,
|
||||||
delete(m.records, qA)
|
Qtype: dns.TypeA,
|
||||||
delete(m.records, qAAAA)
|
Qclass: dns.ClassINET,
|
||||||
delete(m.refreshing, qA)
|
}
|
||||||
delete(m.refreshing, qAAAA)
|
delete(m.records, aQuestion)
|
||||||
|
|
||||||
|
aaaaQuestion := dns.Question{
|
||||||
|
Name: dnsName,
|
||||||
|
Qtype: dns.TypeAAAA,
|
||||||
|
Qclass: dns.ClassINET,
|
||||||
|
}
|
||||||
|
delete(m.records, aaaaQuestion)
|
||||||
|
|
||||||
log.Debugf("removed domain=%s from cache", d.SafeString())
|
log.Debugf("removed domain=%s from cache", d.SafeString())
|
||||||
return nil
|
return nil
|
||||||
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
|
|
||||||
return domains
|
return domains
|
||||||
}
|
}
|
||||||
|
|
||||||
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl. Non
|
|
||||||
// A/AAAA records return nil.
|
|
||||||
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
|
|
||||||
switch r := rr.(type) {
|
|
||||||
case *dns.A:
|
|
||||||
cp := *r
|
|
||||||
cp.Hdr.Name = owner
|
|
||||||
cp.Hdr.Ttl = ttl
|
|
||||||
cp.A = slices.Clone(r.A)
|
|
||||||
return &cp
|
|
||||||
case *dns.AAAA:
|
|
||||||
cp := *r
|
|
||||||
cp.Hdr.Name = owner
|
|
||||||
cp.Hdr.Ttl = ttl
|
|
||||||
cp.AAAA = slices.Clone(r.AAAA)
|
|
||||||
return &cp
|
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
|
|
||||||
// stamping ttl so the response shares no memory with the cached slice.
|
|
||||||
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
|
|
||||||
out := make([]dns.RR, 0, len(records))
|
|
||||||
for _, rr := range records {
|
|
||||||
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
|
|
||||||
out = append(out, cp)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return out
|
|
||||||
}
|
|
||||||
|
|
||||||
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
|
|
||||||
// in answer, iterating until fixed point so out-of-order chains resolve.
|
|
||||||
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
|
|
||||||
owners := map[string]bool{dnsName: true}
|
|
||||||
for {
|
|
||||||
added := false
|
|
||||||
for _, rr := range answer {
|
|
||||||
cname, ok := rr.(*dns.CNAME)
|
|
||||||
if !ok {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
|
|
||||||
if !owners[name] {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
target := strings.ToLower(dns.Fqdn(cname.Target))
|
|
||||||
if !owners[target] {
|
|
||||||
owners[target] = true
|
|
||||||
added = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !added {
|
|
||||||
return owners
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
|
|
||||||
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
|
|
||||||
func resolveCacheTTL() time.Duration {
|
|
||||||
if v := os.Getenv(envMgmtCacheTTL); v != "" {
|
|
||||||
if d, err := time.ParseDuration(v); err == nil && d > 0 {
|
|
||||||
return d
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return defaultTTL
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,408 +0,0 @@
|
|||||||
package mgmt
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"errors"
|
|
||||||
"net"
|
|
||||||
"sync"
|
|
||||||
"sync/atomic"
|
|
||||||
"testing"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/dns/test"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
)
|
|
||||||
|
|
||||||
type fakeChain struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
calls map[string]int
|
|
||||||
answers map[string][]dns.RR
|
|
||||||
err error
|
|
||||||
hasRoot bool
|
|
||||||
onLookup func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func newFakeChain() *fakeChain {
|
|
||||||
return &fakeChain{
|
|
||||||
calls: map[string]int{},
|
|
||||||
answers: map[string][]dns.RR{},
|
|
||||||
hasRoot: true,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return f.hasRoot
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
|
|
||||||
f.mu.Lock()
|
|
||||||
q := msg.Question[0]
|
|
||||||
key := q.Name + "|" + dns.TypeToString[q.Qtype]
|
|
||||||
f.calls[key]++
|
|
||||||
answers := f.answers[key]
|
|
||||||
err := f.err
|
|
||||||
onLookup := f.onLookup
|
|
||||||
f.mu.Unlock()
|
|
||||||
|
|
||||||
if onLookup != nil {
|
|
||||||
onLookup()
|
|
||||||
}
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
resp := &dns.Msg{}
|
|
||||||
resp.SetReply(msg)
|
|
||||||
resp.Answer = answers
|
|
||||||
return resp, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
key := name + "|" + dns.TypeToString[qtype]
|
|
||||||
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
|
|
||||||
switch qtype {
|
|
||||||
case dns.TypeA:
|
|
||||||
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
|
|
||||||
case dns.TypeAAAA:
|
|
||||||
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fakeChain) callCount(name string, qtype uint16) int {
|
|
||||||
f.mu.Lock()
|
|
||||||
defer f.mu.Unlock()
|
|
||||||
return f.calls[name+"|"+dns.TypeToString[qtype]]
|
|
||||||
}
|
|
||||||
|
|
||||||
// waitFor polls the predicate until it returns true or the deadline passes.
|
|
||||||
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
|
|
||||||
t.Helper()
|
|
||||||
deadline := time.Now().Add(d)
|
|
||||||
for time.Now().Before(deadline) {
|
|
||||||
if fn() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
time.Sleep(5 * time.Millisecond)
|
|
||||||
}
|
|
||||||
t.Fatalf("condition not met within %s", d)
|
|
||||||
}
|
|
||||||
|
|
||||||
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
|
|
||||||
t.Helper()
|
|
||||||
msg := new(dns.Msg)
|
|
||||||
msg.SetQuestion(name, dns.TypeA)
|
|
||||||
w := &test.MockResponseWriter{}
|
|
||||||
r.ServeDNS(w, msg)
|
|
||||||
return w.GetLastResponse()
|
|
||||||
}
|
|
||||||
|
|
||||||
func firstA(t *testing.T, resp *dns.Msg) string {
|
|
||||||
t.Helper()
|
|
||||||
require.NotNil(t, resp)
|
|
||||||
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
|
|
||||||
a, ok := resp.Answer[0].(*dns.A)
|
|
||||||
require.True(t, ok, "expected A record")
|
|
||||||
return a.A.String()
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
|
|
||||||
// Same cached entry age, different cacheTTL values: the shorter TTL must
|
|
||||||
// trigger a background refresh, the longer one must not. Proves that the
|
|
||||||
// per-Resolver cacheTTL field actually drives the stale decision.
|
|
||||||
cachedAt := time.Now().Add(-100 * time.Millisecond)
|
|
||||||
|
|
||||||
newRec := func() *cachedRecord {
|
|
||||||
return &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: cachedAt,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
|
|
||||||
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
r.cacheTTL = 10 * time.Millisecond
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
r.records[q] = newRec()
|
|
||||||
|
|
||||||
resp := queryA(t, r, q.Name)
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount(q.Name, dns.TypeA) >= 1
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
r.cacheTTL = time.Hour
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
r.records[q] = newRec()
|
|
||||||
|
|
||||||
resp := queryA(t, r, q.Name)
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
|
||||||
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(), // fresh
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp))
|
|
||||||
|
|
||||||
time.Sleep(20 * time.Millisecond)
|
|
||||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
|
|
||||||
}
|
|
||||||
|
|
||||||
// First query: serves stale immediately.
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
|
|
||||||
})
|
|
||||||
|
|
||||||
// Next query should now return the refreshed IP.
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
|
|
||||||
var inflight atomic.Int32
|
|
||||||
var maxInflight atomic.Int32
|
|
||||||
chain.onLookup = func() {
|
|
||||||
cur := inflight.Add(1)
|
|
||||||
defer inflight.Add(-1)
|
|
||||||
for {
|
|
||||||
prev := maxInflight.Load()
|
|
||||||
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
|
|
||||||
}
|
|
||||||
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
|
||||||
}
|
|
||||||
|
|
||||||
var wg sync.WaitGroup
|
|
||||||
for i := 0; i < 50; i++ {
|
|
||||||
wg.Add(1)
|
|
||||||
go func() {
|
|
||||||
defer wg.Done()
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}()
|
|
||||||
}
|
|
||||||
wg.Wait()
|
|
||||||
|
|
||||||
waitFor(t, 2*time.Second, func() bool {
|
|
||||||
return inflight.Load() == 0
|
|
||||||
})
|
|
||||||
|
|
||||||
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
|
|
||||||
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
|
|
||||||
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.err = errors.New("boom")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now().Add(-2 * defaultTTL),
|
|
||||||
}
|
|
||||||
|
|
||||||
// First stale hit triggers a refresh attempt that fails.
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
|
|
||||||
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
|
|
||||||
})
|
|
||||||
waitFor(t, time.Second, func() bool {
|
|
||||||
r.mutex.RLock()
|
|
||||||
defer r.mutex.RUnlock()
|
|
||||||
c, ok := r.records[q]
|
|
||||||
return ok && !c.lastFailedRefresh.IsZero()
|
|
||||||
})
|
|
||||||
|
|
||||||
// Subsequent stale hits within backoff window should not schedule more refreshes.
|
|
||||||
for i := 0; i < 10; i++ {
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}
|
|
||||||
time.Sleep(50 * time.Millisecond)
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.hasRoot = false
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
// With hasRoot=false the chain must not be consulted. Use a short
|
|
||||||
// deadline so the OS fallback returns quickly without waiting on a
|
|
||||||
// real network call in CI.
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
|
|
||||||
defer cancel()
|
|
||||||
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
|
|
||||||
|
|
||||||
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
|
|
||||||
"chain must not be used when no root handler is registered at the bound priority")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
|
|
||||||
// ServeDNS being invoked for a question while a refresh for that question
|
|
||||||
// is inflight indicates a resolver loop (OS resolver sent the recursive
|
|
||||||
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Simulate an inflight refresh.
|
|
||||||
r.markRefreshing(q)
|
|
||||||
defer r.clearRefreshing(q)
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
inflight := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
require.NotNil(t, inflight)
|
|
||||||
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
r.markRefreshing(q)
|
|
||||||
defer r.clearRefreshing(q)
|
|
||||||
|
|
||||||
// Multiple ServeDNS calls during the same refresh must not re-set the flag
|
|
||||||
// (CompareAndSwap from false -> true returns true only on the first call).
|
|
||||||
for range 5 {
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
inflight := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
assert.True(t, inflight.Load())
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
|
|
||||||
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
|
|
||||||
r.records[q] = &cachedRecord{
|
|
||||||
records: []dns.RR{&dns.A{
|
|
||||||
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
|
|
||||||
A: net.ParseIP("10.0.0.1").To4(),
|
|
||||||
}},
|
|
||||||
cachedAt: time.Now(),
|
|
||||||
}
|
|
||||||
|
|
||||||
queryA(t, r, "mgmt.example.com.")
|
|
||||||
|
|
||||||
r.mutex.RLock()
|
|
||||||
_, ok := r.refreshing[q]
|
|
||||||
r.mutex.RUnlock()
|
|
||||||
assert.False(t, ok, "no refresh inflight means no loop tracking")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
|
|
||||||
r := NewResolver()
|
|
||||||
chain := newFakeChain()
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
|
|
||||||
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
|
|
||||||
r.SetChainResolver(chain, 50)
|
|
||||||
|
|
||||||
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
|
|
||||||
|
|
||||||
resp := queryA(t, r, "mgmt.example.com.")
|
|
||||||
assert.Equal(t, "10.0.0.2", firstA(t, resp))
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
|
|
||||||
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
|
|
||||||
}
|
|
||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
"net/url"
|
"net/url"
|
||||||
"strings"
|
"strings"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
|
|||||||
assert.False(t, resolver.MatchSubdomains())
|
assert.False(t, resolver.MatchSubdomains())
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestResolveCacheTTL(t *testing.T) {
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
value string
|
|
||||||
want time.Duration
|
|
||||||
}{
|
|
||||||
{"unset falls back to default", "", defaultTTL},
|
|
||||||
{"valid duration", "45s", 45 * time.Second},
|
|
||||||
{"valid minutes", "2m", 2 * time.Minute},
|
|
||||||
{"malformed falls back to default", "not-a-duration", defaultTTL},
|
|
||||||
{"zero falls back to default", "0s", defaultTTL},
|
|
||||||
{"negative falls back to default", "-5s", defaultTTL},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
t.Setenv(envMgmtCacheTTL, tc.value)
|
|
||||||
got := resolveCacheTTL()
|
|
||||||
assert.Equal(t, tc.want, got, "parsed TTL should match")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
|
|
||||||
t.Setenv(envMgmtCacheTTL, "7s")
|
|
||||||
r := NewResolver()
|
|
||||||
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ResponseTTL(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
tests := []struct {
|
|
||||||
name string
|
|
||||||
cacheTTL time.Duration
|
|
||||||
cachedAt time.Time
|
|
||||||
wantMin uint32
|
|
||||||
wantMax uint32
|
|
||||||
}{
|
|
||||||
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
|
|
||||||
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
|
|
||||||
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
|
|
||||||
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range tests {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
r := &Resolver{cacheTTL: tc.cacheTTL}
|
|
||||||
got := r.responseTTL(tc.cachedAt)
|
|
||||||
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
|
|
||||||
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
func TestResolver_ExtractDomainFromURL(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -212,7 +212,6 @@ func newDefaultServer(
|
|||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
|
|
||||||
mgmtCacheResolver := mgmt.NewResolver()
|
mgmtCacheResolver := mgmt.NewResolver()
|
||||||
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
|
|
||||||
|
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
|
|||||||
@@ -28,7 +28,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
"github.com/netbirdio/netbird/client/firewall/firewalld"
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
@@ -69,7 +68,6 @@ import (
|
|||||||
signal "github.com/netbirdio/netbird/shared/signal/client"
|
signal "github.com/netbirdio/netbird/shared/signal/client"
|
||||||
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
sProto "github.com/netbirdio/netbird/shared/signal/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/util/capture"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
|
||||||
@@ -220,8 +218,6 @@ type Engine struct {
|
|||||||
portForwardManager *portforward.Manager
|
portForwardManager *portforward.Manager
|
||||||
srWatcher *guard.SRWatcher
|
srWatcher *guard.SRWatcher
|
||||||
|
|
||||||
afpacketCapture *capture.AFPacketCapture
|
|
||||||
|
|
||||||
// Sync response persistence (protected by syncRespMux)
|
// Sync response persistence (protected by syncRespMux)
|
||||||
syncRespMux sync.RWMutex
|
syncRespMux sync.RWMutex
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
@@ -948,12 +944,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
|||||||
return fmt.Errorf("update relay token: %w", err)
|
return fmt.Errorf("update relay token: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
urls := update.Urls
|
e.relayManager.UpdateServerURLs(update.Urls)
|
||||||
if override, ok := peer.OverrideRelayURLs(); ok {
|
|
||||||
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
|
|
||||||
urls = override
|
|
||||||
}
|
|
||||||
e.relayManager.UpdateServerURLs(urls)
|
|
||||||
|
|
||||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||||
@@ -1707,11 +1698,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) close() {
|
func (e *Engine) close() {
|
||||||
if e.afpacketCapture != nil {
|
|
||||||
e.afpacketCapture.Stop()
|
|
||||||
e.afpacketCapture = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
|
||||||
|
|
||||||
if e.wgInterface != nil {
|
if e.wgInterface != nil {
|
||||||
@@ -2177,62 +2163,6 @@ func (e *Engine) Address() (netip.Addr, error) {
|
|||||||
return e.wgInterface.Address().IP, nil
|
return e.wgInterface.Address().IP, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetCapture sets or clears packet capture on the WireGuard device.
|
|
||||||
// On userspace WireGuard, it taps the FilteredDevice directly.
|
|
||||||
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
|
|
||||||
// Pass nil to disable capture.
|
|
||||||
func (e *Engine) SetCapture(pc device.PacketCapture) error {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
intf := e.wgInterface
|
|
||||||
if intf == nil {
|
|
||||||
return errors.New("wireguard interface not initialized")
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.afpacketCapture != nil {
|
|
||||||
e.afpacketCapture.Stop()
|
|
||||||
e.afpacketCapture = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
dev := intf.GetDevice()
|
|
||||||
if dev != nil {
|
|
||||||
dev.SetCapture(pc)
|
|
||||||
e.setForwarderCapture(pc)
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
|
|
||||||
if pc == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
sess, ok := pc.(*capture.Session)
|
|
||||||
if !ok {
|
|
||||||
return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
|
|
||||||
}
|
|
||||||
|
|
||||||
afc := capture.NewAFPacketCapture(intf.Name(), sess)
|
|
||||||
if err := afc.Start(); err != nil {
|
|
||||||
return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
|
|
||||||
}
|
|
||||||
e.afpacketCapture = afc
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
|
|
||||||
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
|
|
||||||
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
|
|
||||||
if e.firewall == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
type forwarderCapturer interface {
|
|
||||||
SetPacketCapture(pc forwarder.PacketCapture)
|
|
||||||
}
|
|
||||||
if fc, ok := e.firewall.(forwarderCapturer); ok {
|
|
||||||
fc.SetPacketCapture(pc)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
|
||||||
if e.firewall == nil {
|
if e.firewall == nil {
|
||||||
log.Warn("firewall is disabled, not updating forwarding rules")
|
log.Warn("firewall is disabled, not updating forwarding rules")
|
||||||
@@ -2454,8 +2384,6 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())
|
|
||||||
|
|
||||||
offerAnswer := peer.OfferAnswer{
|
offerAnswer := peer.OfferAnswer{
|
||||||
IceCredentials: peer.IceCredentials{
|
IceCredentials: peer.IceCredentials{
|
||||||
UFrag: remoteCred.UFrag,
|
UFrag: remoteCred.UFrag,
|
||||||
@@ -2466,23 +2394,7 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
|
|||||||
RosenpassPubKey: rosenpassPubKey,
|
RosenpassPubKey: rosenpassPubKey,
|
||||||
RosenpassAddr: rosenpassAddr,
|
RosenpassAddr: rosenpassAddr,
|
||||||
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
|
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
|
||||||
RelaySrvIP: relayIP,
|
|
||||||
SessionID: sessionID,
|
SessionID: sessionID,
|
||||||
}
|
}
|
||||||
return &offerAnswer, nil
|
return &offerAnswer, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
|
|
||||||
// netip.Addr. Returns the zero value for empty input and logs a warning
|
|
||||||
// for malformed payloads.
|
|
||||||
func decodeRelayIP(b []byte) netip.Addr {
|
|
||||||
if len(b) == 0 {
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
ip, ok := netip.AddrFromSlice(b)
|
|
||||||
if !ok {
|
|
||||||
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
|
|
||||||
return netip.Addr{}
|
|
||||||
}
|
|
||||||
return ip.Unmap()
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package activity
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -17,6 +18,10 @@ import (
|
|||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func isBindListenerPlatform() bool {
|
||||||
|
return runtime.GOOS == "windows" || runtime.GOOS == "js"
|
||||||
|
}
|
||||||
|
|
||||||
// mockEndpointManager implements device.EndpointManager for testing
|
// mockEndpointManager implements device.EndpointManager for testing
|
||||||
type mockEndpointManager struct {
|
type mockEndpointManager struct {
|
||||||
endpoints map[netip.Addr]net.Conn
|
endpoints map[netip.Addr]net.Conn
|
||||||
@@ -176,6 +181,10 @@ func TestBindListener_Close(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_BindMode(t *testing.T) {
|
func TestManager_BindMode(t *testing.T) {
|
||||||
|
if !isBindListenerPlatform() {
|
||||||
|
t.Skip("BindListener only used on Windows/JS platforms")
|
||||||
|
}
|
||||||
|
|
||||||
mockEndpointMgr := newMockEndpointManager()
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
@@ -217,6 +226,10 @@ func TestManager_BindMode(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
func TestManager_BindMode_MultiplePeers(t *testing.T) {
|
||||||
|
if !isBindListenerPlatform() {
|
||||||
|
t.Skip("BindListener only used on Windows/JS platforms")
|
||||||
|
}
|
||||||
|
|
||||||
mockEndpointMgr := newMockEndpointManager()
|
mockEndpointMgr := newMockEndpointManager()
|
||||||
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
|
||||||
|
|
||||||
|
|||||||
@@ -4,12 +4,14 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -73,6 +75,16 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// BindListener is used on Windows, JS, and netstack platforms:
|
||||||
|
// - JS: Cannot listen to UDP sockets
|
||||||
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
|
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||||
|
// BindListener bypasses these issues by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||||
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
provider, ok := m.wgIface.(bindProvider)
|
provider, ok := m.wgIface.(bindProvider)
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
"github.com/netbirdio/netbird/client/internal/lazyconn/activity"
|
||||||
@@ -90,8 +91,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
|
|||||||
m.routesMu.Lock()
|
m.routesMu.Lock()
|
||||||
defer m.routesMu.Unlock()
|
defer m.routesMu.Unlock()
|
||||||
|
|
||||||
clear(m.peerToHAGroups)
|
maps.Clear(m.peerToHAGroups)
|
||||||
clear(m.haGroupToPeers)
|
maps.Clear(m.haGroupToPeers)
|
||||||
|
|
||||||
for haUniqueID, routes := range haMap {
|
for haUniqueID, routes := range haMap {
|
||||||
var peers []string
|
var peers []string
|
||||||
|
|||||||
@@ -3,6 +3,8 @@ package store
|
|||||||
import (
|
import (
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
"github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
@@ -28,7 +30,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
|
|||||||
func (m *Memory) Close() {
|
func (m *Memory) Close() {
|
||||||
m.mux.Lock()
|
m.mux.Lock()
|
||||||
defer m.mux.Unlock()
|
defer m.mux.Unlock()
|
||||||
clear(m.events)
|
maps.Clear(m.events)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Memory) GetEvents() []*types.Event {
|
func (m *Memory) GetEvents() []*types.Event {
|
||||||
|
|||||||
@@ -7,8 +7,7 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
EnvKeyNBForceRelay = "NB_FORCE_RELAY"
|
||||||
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func IsForceRelayed() bool {
|
func IsForceRelayed() bool {
|
||||||
@@ -17,28 +16,3 @@ func IsForceRelayed() bool {
|
|||||||
}
|
}
|
||||||
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
|
||||||
}
|
}
|
||||||
|
|
||||||
// OverrideRelayURLs returns the relay server URL list set in
|
|
||||||
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
|
|
||||||
// the override is active. When the env var is unset, the boolean is false
|
|
||||||
// and the caller should keep the list received from the management server.
|
|
||||||
// Intended for lab/debug scenarios where a peer must pin to a specific home
|
|
||||||
// relay regardless of what management offers.
|
|
||||||
func OverrideRelayURLs() ([]string, bool) {
|
|
||||||
raw := os.Getenv(EnvKeyNBHomeRelayServers)
|
|
||||||
if raw == "" {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
parts := strings.Split(raw, ",")
|
|
||||||
urls := make([]string, 0, len(parts))
|
|
||||||
for _, p := range parts {
|
|
||||||
p = strings.TrimSpace(p)
|
|
||||||
if p != "" {
|
|
||||||
urls = append(urls, p)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if len(urls) == 0 {
|
|
||||||
return nil, false
|
|
||||||
}
|
|
||||||
return urls, true
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package peer
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net/netip"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -41,10 +40,6 @@ type OfferAnswer struct {
|
|||||||
|
|
||||||
// relay server address
|
// relay server address
|
||||||
RelaySrvAddress string
|
RelaySrvAddress string
|
||||||
// RelaySrvIP is the IP the remote peer is connected to on its
|
|
||||||
// relay server. Used as a dial target if DNS for RelaySrvAddress
|
|
||||||
// fails. Zero value if the peer did not advertise an IP.
|
|
||||||
RelaySrvIP netip.Addr
|
|
||||||
// SessionID is the unique identifier of the session, used to discard old messages
|
// SessionID is the unique identifier of the session, used to discard old messages
|
||||||
SessionID *ICESessionID
|
SessionID *ICESessionID
|
||||||
}
|
}
|
||||||
@@ -222,9 +217,8 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
|
|||||||
answer.SessionID = &sid
|
answer.SessionID = &sid
|
||||||
}
|
}
|
||||||
|
|
||||||
if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil {
|
if addr, err := h.relay.RelayInstanceAddress(); err == nil {
|
||||||
answer.RelaySrvAddress = addr
|
answer.RelaySrvAddress = addr
|
||||||
answer.RelaySrvIP = ip
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return answer
|
return answer
|
||||||
|
|||||||
@@ -8,7 +8,6 @@ import (
|
|||||||
type mocListener struct {
|
type mocListener struct {
|
||||||
lastState int
|
lastState int
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
peersWg sync.WaitGroup
|
|
||||||
peers int
|
peers int
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -34,7 +33,6 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
|
|||||||
}
|
}
|
||||||
func (l *mocListener) OnPeersListChanged(size int) {
|
func (l *mocListener) OnPeersListChanged(size int) {
|
||||||
l.peers = size
|
l.peers = size
|
||||||
l.peersWg.Done()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *mocListener) setWaiter() {
|
func (l *mocListener) setWaiter() {
|
||||||
@@ -45,14 +43,6 @@ func (l *mocListener) wait() {
|
|||||||
l.wg.Wait()
|
l.wg.Wait()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *mocListener) setPeersWaiter() {
|
|
||||||
l.peersWg.Add(1)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *mocListener) waitPeers() {
|
|
||||||
l.peersWg.Wait()
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_notifier_serverState(t *testing.T) {
|
func Test_notifier_serverState(t *testing.T) {
|
||||||
|
|
||||||
type scenario struct {
|
type scenario struct {
|
||||||
@@ -82,13 +72,11 @@ func Test_notifier_serverState(t *testing.T) {
|
|||||||
func Test_notifier_SetListener(t *testing.T) {
|
func Test_notifier_SetListener(t *testing.T) {
|
||||||
listener := &mocListener{}
|
listener := &mocListener{}
|
||||||
listener.setWaiter()
|
listener.setWaiter()
|
||||||
listener.setPeersWaiter()
|
|
||||||
|
|
||||||
n := newNotifier()
|
n := newNotifier()
|
||||||
n.lastNotification = stateConnecting
|
n.lastNotification = stateConnecting
|
||||||
n.setListener(listener)
|
n.setListener(listener)
|
||||||
listener.wait()
|
listener.wait()
|
||||||
listener.waitPeers()
|
|
||||||
if listener.lastState != n.lastNotification {
|
if listener.lastState != n.lastNotification {
|
||||||
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
|
||||||
}
|
}
|
||||||
@@ -97,14 +85,9 @@ func Test_notifier_SetListener(t *testing.T) {
|
|||||||
func Test_notifier_RemoveListener(t *testing.T) {
|
func Test_notifier_RemoveListener(t *testing.T) {
|
||||||
listener := &mocListener{}
|
listener := &mocListener{}
|
||||||
listener.setWaiter()
|
listener.setWaiter()
|
||||||
listener.setPeersWaiter()
|
|
||||||
n := newNotifier()
|
n := newNotifier()
|
||||||
n.lastNotification = stateConnecting
|
n.lastNotification = stateConnecting
|
||||||
n.setListener(listener)
|
n.setListener(listener)
|
||||||
// setListener replays cached state on a goroutine; wait for both the state
|
|
||||||
// and peers callbacks to finish so we don't race on listener.peers.
|
|
||||||
listener.wait()
|
|
||||||
listener.waitPeers()
|
|
||||||
n.removeListener()
|
n.removeListener()
|
||||||
n.peerListChanged(1)
|
n.peerListChanged(1)
|
||||||
|
|
||||||
|
|||||||
@@ -54,19 +54,19 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
|
|||||||
log.Warnf("failed to get session ID bytes: %v", err)
|
log.Warnf("failed to get session ID bytes: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{
|
msg, err := signal.MarshalCredential(
|
||||||
Type: bodyType,
|
s.wgPrivateKey,
|
||||||
WgListenPort: offerAnswer.WgListenPort,
|
offerAnswer.WgListenPort,
|
||||||
Credential: &signal.Credential{
|
remoteKey,
|
||||||
|
&signal.Credential{
|
||||||
UFrag: offerAnswer.IceCredentials.UFrag,
|
UFrag: offerAnswer.IceCredentials.UFrag,
|
||||||
Pwd: offerAnswer.IceCredentials.Pwd,
|
Pwd: offerAnswer.IceCredentials.Pwd,
|
||||||
},
|
},
|
||||||
RosenpassPubKey: offerAnswer.RosenpassPubKey,
|
bodyType,
|
||||||
RosenpassAddr: offerAnswer.RosenpassAddr,
|
offerAnswer.RosenpassPubKey,
|
||||||
RelaySrvAddress: offerAnswer.RelaySrvAddress,
|
offerAnswer.RosenpassAddr,
|
||||||
RelaySrvIP: offerAnswer.RelaySrvIP,
|
offerAnswer.RelaySrvAddress,
|
||||||
SessionID: sessionIDBytes,
|
sessionIDBytes)
|
||||||
})
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
|
|||||||
// UpdatePeerState updates peer status
|
// UpdatePeerState updates peer status
|
||||||
func (d *Status) UpdatePeerState(receivedState State) error {
|
func (d *Status) UpdatePeerState(receivedState State) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
peerState, ok := d.peers[receivedState.PubKey]
|
peerState, ok := d.peers[receivedState.PubKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
d.mux.Unlock()
|
|
||||||
return errors.New("peer doesn't exist")
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -343,29 +343,23 @@ func (d *Status) UpdatePeerState(receivedState State) error {
|
|||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||||
// when we close the connection we will not notify the router manager
|
d.notifyPeerListChanged()
|
||||||
notifyRouter := receivedState.ConnStatus == StatusIdle
|
|
||||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
|
||||||
numPeers := d.numOfPeers()
|
|
||||||
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
if notifyList {
|
|
||||||
d.notifier.peerListChanged(numPeers)
|
|
||||||
}
|
}
|
||||||
if notifyRouter {
|
|
||||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
// when we close the connection we will not notify the router manager
|
||||||
|
if receivedState.ConnStatus == StatusIdle {
|
||||||
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
|
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
peerState, ok := d.peers[peer]
|
peerState, ok := d.peers[peer]
|
||||||
if !ok {
|
if !ok {
|
||||||
d.mux.Unlock()
|
|
||||||
return errors.New("peer doesn't exist")
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -377,20 +371,17 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
|
|||||||
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
|
d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
|
||||||
}
|
}
|
||||||
|
|
||||||
numPeers := d.numOfPeers()
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
// todo: consider to make sense of this notification or not
|
// todo: consider to make sense of this notification or not
|
||||||
d.notifier.peerListChanged(numPeers)
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
peerState, ok := d.peers[peer]
|
peerState, ok := d.peers[peer]
|
||||||
if !ok {
|
if !ok {
|
||||||
d.mux.Unlock()
|
|
||||||
return errors.New("peer doesn't exist")
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -402,11 +393,8 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
|
|||||||
d.routeIDLookup.RemoveRemoteRouteID(pref)
|
d.routeIDLookup.RemoveRemoteRouteID(pref)
|
||||||
}
|
}
|
||||||
|
|
||||||
numPeers := d.numOfPeers()
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
// todo: consider to make sense of this notification or not
|
// todo: consider to make sense of this notification or not
|
||||||
d.notifier.peerListChanged(numPeers)
|
d.notifyPeerListChanged()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -422,10 +410,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
|
|||||||
|
|
||||||
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
func (d *Status) UpdatePeerICEState(receivedState State) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
peerState, ok := d.peers[receivedState.PubKey]
|
peerState, ok := d.peers[receivedState.PubKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
d.mux.Unlock()
|
|
||||||
return errors.New("peer doesn't exist")
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -443,28 +431,22 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
|
|||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
d.notifyPeerListChanged()
|
||||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
|
||||||
numPeers := d.numOfPeers()
|
|
||||||
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
if notifyList {
|
|
||||||
d.notifier.peerListChanged(numPeers)
|
|
||||||
}
|
}
|
||||||
if notifyRouter {
|
|
||||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||||
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
peerState, ok := d.peers[receivedState.PubKey]
|
peerState, ok := d.peers[receivedState.PubKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
d.mux.Unlock()
|
|
||||||
return errors.New("peer doesn't exist")
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,28 +461,22 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
|
|||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
d.notifyPeerListChanged()
|
||||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
|
||||||
numPeers := d.numOfPeers()
|
|
||||||
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
if notifyList {
|
|
||||||
d.notifier.peerListChanged(numPeers)
|
|
||||||
}
|
}
|
||||||
if notifyRouter {
|
|
||||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||||
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
|
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
peerState, ok := d.peers[receivedState.PubKey]
|
peerState, ok := d.peers[receivedState.PubKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
d.mux.Unlock()
|
|
||||||
return errors.New("peer doesn't exist")
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -514,28 +490,22 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
|
|||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
d.notifyPeerListChanged()
|
||||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
|
||||||
numPeers := d.numOfPeers()
|
|
||||||
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
if notifyList {
|
|
||||||
d.notifier.peerListChanged(numPeers)
|
|
||||||
}
|
}
|
||||||
if notifyRouter {
|
|
||||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||||
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
peerState, ok := d.peers[receivedState.PubKey]
|
peerState, ok := d.peers[receivedState.PubKey]
|
||||||
if !ok {
|
if !ok {
|
||||||
d.mux.Unlock()
|
|
||||||
return errors.New("peer doesn't exist")
|
return errors.New("peer doesn't exist")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -552,18 +522,12 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
|
|||||||
|
|
||||||
d.peers[receivedState.PubKey] = peerState
|
d.peers[receivedState.PubKey] = peerState
|
||||||
|
|
||||||
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus)
|
if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
|
||||||
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed)
|
d.notifyPeerListChanged()
|
||||||
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
|
|
||||||
numPeers := d.numOfPeers()
|
|
||||||
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
if notifyList {
|
|
||||||
d.notifier.peerListChanged(numPeers)
|
|
||||||
}
|
}
|
||||||
if notifyRouter {
|
|
||||||
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot)
|
if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
|
||||||
|
d.notifyPeerStateChangeListeners(receivedState.PubKey)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -630,33 +594,17 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
|
|||||||
// FinishPeerListModifications this event invoke the notification
|
// FinishPeerListModifications this event invoke the notification
|
||||||
func (d *Status) FinishPeerListModifications() {
|
func (d *Status) FinishPeerListModifications() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
|
||||||
if !d.peerListChangedForNotification {
|
if !d.peerListChangedForNotification {
|
||||||
d.mux.Unlock()
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
d.peerListChangedForNotification = false
|
d.peerListChangedForNotification = false
|
||||||
|
|
||||||
numPeers := d.numOfPeers()
|
d.notifyPeerListChanged()
|
||||||
|
|
||||||
// snapshot per-peer router state to deliver after the lock is released
|
|
||||||
type routerDispatch struct {
|
|
||||||
peerID string
|
|
||||||
snapshot map[string]RouterState
|
|
||||||
}
|
|
||||||
dispatches := make([]routerDispatch, 0, len(d.peers))
|
|
||||||
for key := range d.peers {
|
for key := range d.peers {
|
||||||
snapshot := d.snapshotRouterPeersLocked(key, true)
|
d.notifyPeerStateChangeListeners(key)
|
||||||
if snapshot != nil {
|
|
||||||
dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifier.peerListChanged(numPeers)
|
|
||||||
for _, rd := range dispatches {
|
|
||||||
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -707,12 +655,10 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
|
|||||||
// UpdateLocalPeerState updates local peer status
|
// UpdateLocalPeerState updates local peer status
|
||||||
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
d.localPeer = localPeerState
|
defer d.mux.Unlock()
|
||||||
fqdn := d.localPeer.FQDN
|
|
||||||
ip := d.localPeer.IP
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifier.localAddressChanged(fqdn, ip)
|
d.localPeer = localPeerState
|
||||||
|
d.notifyAddressChanged()
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddLocalPeerStateRoute adds a route to the local peer state
|
// AddLocalPeerStateRoute adds a route to the local peer state
|
||||||
@@ -775,36 +721,30 @@ func (d *Status) CleanLocalPeerStateRoutes() {
|
|||||||
// CleanLocalPeerState cleans local peer status
|
// CleanLocalPeerState cleans local peer status
|
||||||
func (d *Status) CleanLocalPeerState() {
|
func (d *Status) CleanLocalPeerState() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
d.localPeer = LocalPeerState{}
|
defer d.mux.Unlock()
|
||||||
fqdn := d.localPeer.FQDN
|
|
||||||
ip := d.localPeer.IP
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifier.localAddressChanged(fqdn, ip)
|
d.localPeer = LocalPeerState{}
|
||||||
|
d.notifyAddressChanged()
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkManagementDisconnected sets ManagementState to disconnected
|
// MarkManagementDisconnected sets ManagementState to disconnected
|
||||||
func (d *Status) MarkManagementDisconnected(err error) {
|
func (d *Status) MarkManagementDisconnected(err error) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.managementState = false
|
d.managementState = false
|
||||||
d.managementError = err
|
d.managementError = err
|
||||||
mgm := d.managementState
|
|
||||||
sig := d.signalState
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifier.updateServerStates(mgm, sig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkManagementConnected sets ManagementState to connected
|
// MarkManagementConnected sets ManagementState to connected
|
||||||
func (d *Status) MarkManagementConnected() {
|
func (d *Status) MarkManagementConnected() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.managementState = true
|
d.managementState = true
|
||||||
d.managementError = nil
|
d.managementError = nil
|
||||||
mgm := d.managementState
|
|
||||||
sig := d.signalState
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifier.updateServerStates(mgm, sig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateSignalAddress update the address of the signal server
|
// UpdateSignalAddress update the address of the signal server
|
||||||
@@ -838,25 +778,21 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
|
|||||||
// MarkSignalDisconnected sets SignalState to disconnected
|
// MarkSignalDisconnected sets SignalState to disconnected
|
||||||
func (d *Status) MarkSignalDisconnected(err error) {
|
func (d *Status) MarkSignalDisconnected(err error) {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.signalState = false
|
d.signalState = false
|
||||||
d.signalError = err
|
d.signalError = err
|
||||||
mgm := d.managementState
|
|
||||||
sig := d.signalState
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifier.updateServerStates(mgm, sig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// MarkSignalConnected sets SignalState to connected
|
// MarkSignalConnected sets SignalState to connected
|
||||||
func (d *Status) MarkSignalConnected() {
|
func (d *Status) MarkSignalConnected() {
|
||||||
d.mux.Lock()
|
d.mux.Lock()
|
||||||
|
defer d.mux.Unlock()
|
||||||
|
defer d.onConnectionChanged()
|
||||||
|
|
||||||
d.signalState = true
|
d.signalState = true
|
||||||
d.signalError = nil
|
d.signalError = nil
|
||||||
mgm := d.managementState
|
|
||||||
sig := d.signalState
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
d.notifier.updateServerStates(mgm, sig)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
|
||||||
@@ -983,7 +919,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
|
|||||||
|
|
||||||
// if the server connection is not established then we will use the general address
|
// if the server connection is not established then we will use the general address
|
||||||
// in case of connection we will use the instance specific address
|
// in case of connection we will use the instance specific address
|
||||||
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress()
|
instanceAddr, err := d.relayMgr.RelayInstanceAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
// TODO add their status
|
// TODO add their status
|
||||||
for _, r := range d.relayMgr.ServerURLs() {
|
for _, r := range d.relayMgr.ServerURLs() {
|
||||||
@@ -1076,17 +1012,18 @@ func (d *Status) RemoveConnectionListener() {
|
|||||||
d.notifier.removeListener()
|
d.notifier.removeListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers.
|
func (d *Status) onConnectionChanged() {
|
||||||
// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID
|
d.notifier.updateServerStates(d.managementState, d.signalState)
|
||||||
// or when notify is false. The snapshot is consumed later by dispatchRouterPeers
|
}
|
||||||
// outside the lock so the channel send cannot stall any d.mux holder.
|
|
||||||
func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState {
|
// notifyPeerStateChangeListeners notifies route manager about the change in peer state
|
||||||
if !notify {
|
func (d *Status) notifyPeerStateChangeListeners(peerID string) {
|
||||||
return nil
|
subs, ok := d.changeNotify[peerID]
|
||||||
}
|
if !ok {
|
||||||
if _, ok := d.changeNotify[peerID]; !ok {
|
return
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// collect the relevant data for router peers
|
||||||
routerPeers := make(map[string]RouterState, len(d.changeNotify))
|
routerPeers := make(map[string]RouterState, len(d.changeNotify))
|
||||||
for pid := range d.changeNotify {
|
for pid := range d.changeNotify {
|
||||||
s, ok := d.peers[pid]
|
s, ok := d.peers[pid]
|
||||||
@@ -1094,35 +1031,13 @@ func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[strin
|
|||||||
log.Warnf("router peer not found in peers list: %s", pid)
|
log.Warnf("router peer not found in peers list: %s", pid)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
routerPeers[pid] = RouterState{
|
routerPeers[pid] = RouterState{
|
||||||
Status: s.ConnStatus,
|
Status: s.ConnStatus,
|
||||||
Relayed: s.Relayed,
|
Relayed: s.Relayed,
|
||||||
Latency: s.Latency,
|
Latency: s.Latency,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return routerPeers
|
|
||||||
}
|
|
||||||
|
|
||||||
// dispatchRouterPeers delivers a previously snapshotted router-state map to
|
|
||||||
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
|
|
||||||
// fresh, short read of d.changeNotify under the lock to grab subscriber
|
|
||||||
// channels, then sends outside the lock so a slow consumer cannot block other
|
|
||||||
// d.mux holders. The send itself stays blocking (only short-circuited by the
|
|
||||||
// subscriber's context) so peer state transitions are not silently dropped.
|
|
||||||
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
|
|
||||||
if routerPeers == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d.mux.Lock()
|
|
||||||
subsMap, ok := d.changeNotify[peerID]
|
|
||||||
subs := make([]*StatusChangeSubscription, 0, len(subsMap))
|
|
||||||
if ok {
|
|
||||||
for _, sub := range subsMap {
|
|
||||||
subs = append(subs, sub)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
d.mux.Unlock()
|
|
||||||
|
|
||||||
for _, sub := range subs {
|
for _, sub := range subs {
|
||||||
select {
|
select {
|
||||||
@@ -1132,6 +1047,14 @@ func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]Route
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (d *Status) notifyPeerListChanged() {
|
||||||
|
d.notifier.peerListChanged(d.numOfPeers())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Status) notifyAddressChanged() {
|
||||||
|
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
|
||||||
|
}
|
||||||
|
|
||||||
func (d *Status) numOfPeers() int {
|
func (d *Status) numOfPeers() int {
|
||||||
return len(d.peers) + len(d.offlinePeers)
|
return len(d.peers) + len(d.offlinePeers)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,6 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
|
||||||
@@ -54,19 +53,15 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
w.relaySupportedOnRemotePeer.Store(true)
|
w.relaySupportedOnRemotePeer.Store(true)
|
||||||
|
|
||||||
// the relayManager will return with error in case if the connection has lost with relay server
|
// the relayManager will return with error in case if the connection has lost with relay server
|
||||||
currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress()
|
currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
w.log.Errorf("failed to handle new offer: %s", err)
|
w.log.Errorf("failed to handle new offer: %s", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
|
||||||
var serverIP netip.Addr
|
|
||||||
if srv == remoteOfferAnswer.RelaySrvAddress {
|
|
||||||
serverIP = remoteOfferAnswer.RelaySrvIP
|
|
||||||
}
|
|
||||||
|
|
||||||
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP)
|
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
if errors.Is(err, relayClient.ErrConnAlreadyExists) {
|
||||||
w.log.Debugf("handled offer by reusing existing relay connection")
|
w.log.Debugf("handled offer by reusing existing relay connection")
|
||||||
@@ -95,7 +90,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) {
|
func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
|
||||||
return w.relayManager.RelayInstanceAddress()
|
return w.relayManager.RelayInstanceAddress()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -89,16 +89,8 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
|||||||
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
|
||||||
}
|
}
|
||||||
|
|
||||||
reused := false
|
|
||||||
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
if err := r.addScopedDefault(unspec, nexthop); err != nil {
|
||||||
if !errors.Is(err, unix.EEXIST) {
|
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
||||||
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
|
|
||||||
}
|
|
||||||
// macOS installs its own RTF_IFSCOPE defaults for primary service
|
|
||||||
// selection on multi-NIC setups, so a route on this ifindex can
|
|
||||||
// already exist before we try. Binding to it via IP[V6]_BOUND_IF
|
|
||||||
// still produces the scoped lookup we need.
|
|
||||||
reused = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
af := unix.AF_INET
|
af := unix.AF_INET
|
||||||
@@ -110,11 +102,7 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
|
|||||||
if nexthop.IP.IsValid() {
|
if nexthop.IP.IsValid() {
|
||||||
via = nexthop.IP.String()
|
via = nexthop.IP.String()
|
||||||
}
|
}
|
||||||
verb := "installed"
|
log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
|
||||||
if reused {
|
|
||||||
verb = "reused existing"
|
|
||||||
}
|
|
||||||
log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
|
|
||||||
return true, nil
|
return true, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
@@ -43,8 +44,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
|
|||||||
if rs.selectedRoutes == nil {
|
if rs.selectedRoutes == nil {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
clear(rs.deselectedRoutes)
|
maps.Clear(rs.deselectedRoutes)
|
||||||
clear(rs.selectedRoutes)
|
maps.Clear(rs.selectedRoutes)
|
||||||
for _, r := range allRoutes {
|
for _, r := range allRoutes {
|
||||||
rs.deselectedRoutes[r] = struct{}{}
|
rs.deselectedRoutes[r] = struct{}{}
|
||||||
}
|
}
|
||||||
@@ -77,8 +78,8 @@ func (rs *RouteSelector) SelectAllRoutes() {
|
|||||||
if rs.selectedRoutes == nil {
|
if rs.selectedRoutes == nil {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
clear(rs.deselectedRoutes)
|
maps.Clear(rs.deselectedRoutes)
|
||||||
clear(rs.selectedRoutes)
|
maps.Clear(rs.selectedRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeselectRoutes removes specific routes from the selection.
|
// DeselectRoutes removes specific routes from the selection.
|
||||||
@@ -115,8 +116,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
|
|||||||
if rs.selectedRoutes == nil {
|
if rs.selectedRoutes == nil {
|
||||||
rs.selectedRoutes = map[route.NetID]struct{}{}
|
rs.selectedRoutes = map[route.NetID]struct{}{}
|
||||||
}
|
}
|
||||||
clear(rs.deselectedRoutes)
|
maps.Clear(rs.deselectedRoutes)
|
||||||
clear(rs.selectedRoutes)
|
maps.Clear(rs.selectedRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsSelected checks if a specific route is selected.
|
// IsSelected checks if a specific route is selected.
|
||||||
|
|||||||
@@ -2,358 +2,217 @@
|
|||||||
|
|
||||||
package sleep
|
package sleep
|
||||||
|
|
||||||
|
/*
|
||||||
|
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
|
||||||
|
#include <IOKit/pwr_mgt/IOPMLib.h>
|
||||||
|
#include <IOKit/IOMessage.h>
|
||||||
|
#include <CoreFoundation/CoreFoundation.h>
|
||||||
|
|
||||||
|
extern void sleepCallbackBridge();
|
||||||
|
extern void poweredOnCallbackBridge();
|
||||||
|
extern void suspendedCallbackBridge();
|
||||||
|
extern void resumedCallbackBridge();
|
||||||
|
|
||||||
|
|
||||||
|
// C global variables for IOKit state
|
||||||
|
static IONotificationPortRef g_notifyPortRef = NULL;
|
||||||
|
static io_object_t g_notifierObject = 0;
|
||||||
|
static io_object_t g_generalInterestNotifier = 0;
|
||||||
|
static io_connect_t g_rootPort = 0;
|
||||||
|
static CFRunLoopRef g_runLoop = NULL;
|
||||||
|
|
||||||
|
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
|
||||||
|
switch (messageType) {
|
||||||
|
case kIOMessageSystemWillSleep:
|
||||||
|
sleepCallbackBridge();
|
||||||
|
IOAllowPowerChange(g_rootPort, (long)messageArgument);
|
||||||
|
break;
|
||||||
|
case kIOMessageSystemHasPoweredOn:
|
||||||
|
poweredOnCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsSuspended:
|
||||||
|
suspendedCallbackBridge();
|
||||||
|
break;
|
||||||
|
case kIOMessageServiceIsResumed:
|
||||||
|
resumedCallbackBridge();
|
||||||
|
break;
|
||||||
|
default:
|
||||||
|
break;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
static void registerNotifications() {
|
||||||
|
g_rootPort = IORegisterForSystemPower(
|
||||||
|
NULL,
|
||||||
|
&g_notifyPortRef,
|
||||||
|
(IOServiceInterestCallback)sleepCallback,
|
||||||
|
&g_notifierObject
|
||||||
|
);
|
||||||
|
|
||||||
|
if (g_rootPort == 0) {
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
CFRunLoopAddSource(CFRunLoopGetCurrent(),
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
g_runLoop = CFRunLoopGetCurrent();
|
||||||
|
CFRunLoopRun();
|
||||||
|
}
|
||||||
|
|
||||||
|
static void unregisterNotifications() {
|
||||||
|
CFRunLoopRemoveSource(g_runLoop,
|
||||||
|
IONotificationPortGetRunLoopSource(g_notifyPortRef),
|
||||||
|
kCFRunLoopCommonModes);
|
||||||
|
|
||||||
|
IODeregisterForSystemPower(&g_notifierObject);
|
||||||
|
IOServiceClose(g_rootPort);
|
||||||
|
IONotificationPortDestroy(g_notifyPortRef);
|
||||||
|
CFRunLoopStop(g_runLoop);
|
||||||
|
|
||||||
|
g_notifyPortRef = NULL;
|
||||||
|
g_notifierObject = 0;
|
||||||
|
g_rootPort = 0;
|
||||||
|
g_runLoop = NULL;
|
||||||
|
}
|
||||||
|
|
||||||
|
*/
|
||||||
|
import "C"
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"runtime"
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
"unsafe"
|
|
||||||
|
|
||||||
"github.com/ebitengine/purego"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IOKit message types from IOKit/IOMessage.h.
|
|
||||||
const (
|
|
||||||
kIOMessageCanSystemSleep uintptr = 0xe0000270
|
|
||||||
kIOMessageSystemWillSleep uintptr = 0xe0000280
|
|
||||||
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
|
|
||||||
)
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
ioKit iokitFuncs
|
|
||||||
cf cfFuncs
|
|
||||||
cfCommonModes uintptr
|
|
||||||
|
|
||||||
libInitOnce sync.Once
|
|
||||||
libInitErr error
|
|
||||||
|
|
||||||
// callbackThunk is the single C-callable trampoline registered with IOKit.
|
|
||||||
callbackThunk uintptr
|
|
||||||
|
|
||||||
serviceRegistry = make(map[*Detector]struct{})
|
serviceRegistry = make(map[*Detector]struct{})
|
||||||
serviceRegistryMu sync.Mutex
|
serviceRegistryMu sync.Mutex
|
||||||
session *runLoopSession
|
|
||||||
|
|
||||||
// lifecycleMu serializes Register/Deregister so a new registration can't
|
|
||||||
// start a second runloop while a previous teardown is still pending.
|
|
||||||
lifecycleMu sync.Mutex
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// iokitFuncs holds IOKit symbols resolved once at init.
|
//export sleepCallbackBridge
|
||||||
type iokitFuncs struct {
|
func sleepCallbackBridge() {
|
||||||
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr
|
log.Info("sleepCallbackBridge event triggered")
|
||||||
IODeregisterForSystemPower func(notifier *uintptr) int32
|
|
||||||
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
|
|
||||||
IOServiceClose func(connect uintptr) int32
|
|
||||||
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
|
|
||||||
IONotificationPortDestroy func(port uintptr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// cfFuncs holds CoreFoundation symbols resolved once at init.
|
|
||||||
type cfFuncs struct {
|
|
||||||
CFRunLoopGetCurrent func() uintptr
|
|
||||||
CFRunLoopRun func()
|
|
||||||
CFRunLoopStop func(rl uintptr)
|
|
||||||
CFRunLoopAddSource func(rl, source, mode uintptr)
|
|
||||||
CFRunLoopRemoveSource func(rl, source, mode uintptr)
|
|
||||||
}
|
|
||||||
|
|
||||||
// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
|
|
||||||
// session means no runloop is active and the next Register must start one.
|
|
||||||
type runLoopSession struct {
|
|
||||||
rl uintptr
|
|
||||||
port uintptr
|
|
||||||
notifier uintptr
|
|
||||||
rp uintptr
|
|
||||||
}
|
|
||||||
|
|
||||||
// detectorSnapshot pins a detector's callback and done channel so dispatch
|
|
||||||
// runs with values valid at snapshot time, even if a concurrent
|
|
||||||
// Deregister/Register rewrites the detector's fields.
|
|
||||||
type detectorSnapshot struct {
|
|
||||||
detector *Detector
|
|
||||||
callback func(event EventType)
|
|
||||||
done <-chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Detector delivers sleep and wake events to a registered callback.
|
|
||||||
type Detector struct {
|
|
||||||
callback func(event EventType)
|
|
||||||
done chan struct{}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Register installs callback for power events. The first registration starts
|
|
||||||
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
|
|
||||||
// registration succeeds or fails; subsequent registrations just add to the
|
|
||||||
// dispatch set.
|
|
||||||
func (d *Detector) Register(callback func(event EventType)) error {
|
|
||||||
lifecycleMu.Lock()
|
|
||||||
defer lifecycleMu.Unlock()
|
|
||||||
|
|
||||||
serviceRegistryMu.Lock()
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeSleep)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
//export resumedCallbackBridge
|
||||||
|
func resumedCallbackBridge() {
|
||||||
|
log.Info("resumedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export suspendedCallbackBridge
|
||||||
|
func suspendedCallbackBridge() {
|
||||||
|
log.Info("suspendedCallbackBridge event triggered")
|
||||||
|
}
|
||||||
|
|
||||||
|
//export poweredOnCallbackBridge
|
||||||
|
func poweredOnCallbackBridge() {
|
||||||
|
log.Info("poweredOnCallbackBridge event triggered")
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
|
for svc := range serviceRegistry {
|
||||||
|
svc.triggerCallback(EventTypeWakeUp)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
type Detector struct {
|
||||||
|
callback func(event EventType)
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewDetector() (*Detector, error) {
|
||||||
|
return &Detector{}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (d *Detector) Register(callback func(event EventType)) error {
|
||||||
|
serviceRegistryMu.Lock()
|
||||||
|
defer serviceRegistryMu.Unlock()
|
||||||
|
|
||||||
if _, exists := serviceRegistry[d]; exists {
|
if _, exists := serviceRegistry[d]; exists {
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
return fmt.Errorf("detector service already registered")
|
return fmt.Errorf("detector service already registered")
|
||||||
}
|
}
|
||||||
d.callback = callback
|
|
||||||
d.done = make(chan struct{})
|
|
||||||
serviceRegistry[d] = struct{}{}
|
|
||||||
needSetup := session == nil
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
if !needSetup {
|
d.callback = callback
|
||||||
|
|
||||||
|
d.ctx, d.cancel = context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
if len(serviceRegistry) > 0 {
|
||||||
|
serviceRegistry[d] = struct{}{}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
errCh := make(chan error, 1)
|
serviceRegistry[d] = struct{}{}
|
||||||
go runRunLoop(errCh)
|
|
||||||
if err := <-errCh; err != nil {
|
// CFRunLoop must run on a single fixed OS thread
|
||||||
serviceRegistryMu.Lock()
|
go func() {
|
||||||
delete(serviceRegistry, d)
|
runtime.LockOSThread()
|
||||||
close(d.done)
|
defer runtime.UnlockOSThread()
|
||||||
d.done = nil
|
|
||||||
serviceRegistryMu.Unlock()
|
C.registerNotifications()
|
||||||
return err
|
}()
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("sleep detection service started on macOS")
|
log.Info("sleep detection service started on macOS")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Deregister removes the detector. When the last detector leaves, IOKit
|
// Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
|
||||||
// notifications are torn down and the runloop is stopped.
|
// and the runloop is stopped and cleaned up.
|
||||||
func (d *Detector) Deregister() error {
|
func (d *Detector) Deregister() error {
|
||||||
lifecycleMu.Lock()
|
|
||||||
defer lifecycleMu.Unlock()
|
|
||||||
|
|
||||||
serviceRegistryMu.Lock()
|
serviceRegistryMu.Lock()
|
||||||
if _, exists := serviceRegistry[d]; !exists {
|
defer serviceRegistryMu.Unlock()
|
||||||
serviceRegistryMu.Unlock()
|
_, exists := serviceRegistry[d]
|
||||||
|
if !exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
close(d.done)
|
|
||||||
|
// cancel and remove this detector
|
||||||
|
d.cancel()
|
||||||
delete(serviceRegistry, d)
|
delete(serviceRegistry, d)
|
||||||
|
|
||||||
|
// If other Detectors still exist, leave IOKit running
|
||||||
if len(serviceRegistry) > 0 {
|
if len(serviceRegistry) > 0 {
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
sess := session
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
log.Info("sleep detection service stopping (deregister)")
|
log.Info("sleep detection service stopping (deregister)")
|
||||||
|
|
||||||
if sess == nil {
|
// Deregister IOKit notifications, stop runloop, and free resources
|
||||||
return nil
|
C.unregisterNotifications()
|
||||||
}
|
|
||||||
|
|
||||||
if sess.rl != 0 && sess.port != 0 {
|
|
||||||
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
|
|
||||||
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
|
|
||||||
}
|
|
||||||
if sess.notifier != 0 {
|
|
||||||
n := sess.notifier
|
|
||||||
ioKit.IODeregisterForSystemPower(&n)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Clear session only after IODeregisterForSystemPower returns so any
|
|
||||||
// in-flight powerCallback can still look up session.rp to ack sleep.
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
session = nil
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
if sess.rp != 0 {
|
|
||||||
ioKit.IOServiceClose(sess.rp)
|
|
||||||
}
|
|
||||||
if sess.port != 0 {
|
|
||||||
ioKit.IONotificationPortDestroy(sess.port)
|
|
||||||
}
|
|
||||||
if sess.rl != 0 {
|
|
||||||
cf.CFRunLoopStop(sess.rl)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) {
|
func (d *Detector) triggerCallback(event EventType) {
|
||||||
if cb == nil || done == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-done:
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
doneChan := make(chan struct{})
|
doneChan := make(chan struct{})
|
||||||
|
|
||||||
timeout := time.NewTimer(500 * time.Millisecond)
|
timeout := time.NewTimer(500 * time.Millisecond)
|
||||||
defer timeout.Stop()
|
defer timeout.Stop()
|
||||||
|
|
||||||
go func() {
|
cb := d.callback
|
||||||
defer close(doneChan)
|
go func(callback func(event EventType)) {
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.Errorf("panic in sleep callback: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
log.Info("sleep detection event fired")
|
log.Info("sleep detection event fired")
|
||||||
cb(event)
|
callback(event)
|
||||||
}()
|
close(doneChan)
|
||||||
|
}(cb)
|
||||||
|
|
||||||
select {
|
select {
|
||||||
case <-doneChan:
|
case <-doneChan:
|
||||||
case <-done:
|
case <-d.ctx.Done():
|
||||||
case <-timeout.C:
|
case <-timeout.C:
|
||||||
log.Warn("sleep callback timed out")
|
log.Warnf("sleep callback timed out")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
|
|
||||||
func NewDetector() (*Detector, error) {
|
|
||||||
if err := initLibs(); err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
return &Detector{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func initLibs() error {
|
|
||||||
libInitOnce.Do(func() {
|
|
||||||
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
|
||||||
if err != nil {
|
|
||||||
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
|
|
||||||
if err != nil {
|
|
||||||
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
|
|
||||||
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
|
|
||||||
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
|
|
||||||
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
|
|
||||||
|
|
||||||
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
|
|
||||||
if err != nil {
|
|
||||||
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
// Launder the uintptr-to-pointer conversion through a Go variable so
|
|
||||||
// go vet's unsafeptr analyzer doesn't flag a system-library global.
|
|
||||||
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
|
|
||||||
|
|
||||||
// NewCallback slots are a finite, non-reclaimable resource, so register
|
|
||||||
// a single thunk that dispatches to the current Detector set.
|
|
||||||
callbackThunk = purego.NewCallback(powerCallback)
|
|
||||||
})
|
|
||||||
return libInitErr
|
|
||||||
}
|
|
||||||
|
|
||||||
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
|
|
||||||
// runloop thread. A Go panic crossing the purego boundary has undefined
|
|
||||||
// behavior, so contain it here.
|
|
||||||
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.Errorf("panic in sleep powerCallback: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
switch messageType {
|
|
||||||
case kIOMessageCanSystemSleep:
|
|
||||||
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
|
|
||||||
allowPowerChange(messageArgument)
|
|
||||||
case kIOMessageSystemWillSleep:
|
|
||||||
dispatchEvent(EventTypeSleep)
|
|
||||||
allowPowerChange(messageArgument)
|
|
||||||
case kIOMessageSystemHasPoweredOn:
|
|
||||||
dispatchEvent(EventTypeWakeUp)
|
|
||||||
}
|
|
||||||
return 0
|
|
||||||
}
|
|
||||||
|
|
||||||
func allowPowerChange(messageArgument uintptr) {
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
var port uintptr
|
|
||||||
if session != nil {
|
|
||||||
port = session.rp
|
|
||||||
}
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
if port != 0 {
|
|
||||||
ioKit.IOAllowPowerChange(port, messageArgument)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func dispatchEvent(event EventType) {
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
|
|
||||||
for d := range serviceRegistry {
|
|
||||||
snaps = append(snaps, detectorSnapshot{
|
|
||||||
detector: d,
|
|
||||||
callback: d.callback,
|
|
||||||
done: d.done,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
|
|
||||||
for _, s := range snaps {
|
|
||||||
s.detector.triggerCallback(event, s.callback, s.done)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
|
|
||||||
// result is reported on errCh so Register can surface failures synchronously.
|
|
||||||
func runRunLoop(errCh chan<- error) {
|
|
||||||
runtime.LockOSThread()
|
|
||||||
defer runtime.UnlockOSThread()
|
|
||||||
|
|
||||||
sess, err := setupSession()
|
|
||||||
if err == nil {
|
|
||||||
serviceRegistryMu.Lock()
|
|
||||||
session = sess
|
|
||||||
serviceRegistryMu.Unlock()
|
|
||||||
}
|
|
||||||
errCh <- err
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
log.Errorf("panic in sleep runloop: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
cf.CFRunLoopRun()
|
|
||||||
}
|
|
||||||
|
|
||||||
// setupSession performs the IOKit registration on the current thread. Panics
|
|
||||||
// are converted to errors so runRunLoop never leaves errCh unsent.
|
|
||||||
func setupSession() (s *runLoopSession, err error) {
|
|
||||||
defer func() {
|
|
||||||
if r := recover(); r != nil {
|
|
||||||
err = fmt.Errorf("panic during runloop setup: %v", r)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
var portRef, notifier uintptr
|
|
||||||
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, ¬ifier)
|
|
||||||
if rp == 0 {
|
|
||||||
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
|
|
||||||
}
|
|
||||||
|
|
||||||
rl := cf.CFRunLoopGetCurrent()
|
|
||||||
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
|
|
||||||
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
|
|
||||||
|
|
||||||
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -27,10 +28,6 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
|
|||||||
|
|
||||||
// Start begins monitoring the WireGuard interface.
|
// Start begins monitoring the WireGuard interface.
|
||||||
// It relies on the provided context cancellation to stop.
|
// It relies on the provided context cancellation to stop.
|
||||||
//
|
|
||||||
// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
|
|
||||||
// to avoid the allocation churn of repeatedly dumping the kernel link
|
|
||||||
// table; on other platforms it falls back to a low-frequency poll.
|
|
||||||
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
|
||||||
defer close(m.done)
|
defer close(m.done)
|
||||||
|
|
||||||
@@ -59,7 +56,31 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
|
|||||||
|
|
||||||
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)
|
||||||
|
|
||||||
return watchInterface(ctx, ifaceName, expectedIndex)
|
ticker := time.NewTicker(2 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
||||||
|
return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
|
||||||
|
case <-ticker.C:
|
||||||
|
currentIndex, err := getInterfaceIndex(ifaceName)
|
||||||
|
if err != nil {
|
||||||
|
// Interface was deleted
|
||||||
|
log.Infof("Interface monitor: %s deleted", ifaceName)
|
||||||
|
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if interface index changed (interface was recreated)
|
||||||
|
if currentIndex != expectedIndex {
|
||||||
|
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
||||||
|
ifaceName, expectedIndex, currentIndex)
|
||||||
|
return true, nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getInterfaceIndex returns the index of a network interface by name.
|
// getInterfaceIndex returns the index of a network interface by name.
|
||||||
|
|||||||
@@ -1,134 +0,0 @@
|
|||||||
//go:build linux
|
|
||||||
|
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"syscall"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/vishvananda/netlink"
|
|
||||||
)
|
|
||||||
|
|
||||||
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
|
|
||||||
// deletion or recreation of the WireGuard interface.
|
|
||||||
//
|
|
||||||
// The previous implementation polled net.InterfaceByName every 2 s, which
|
|
||||||
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
|
|
||||||
// entire kernel link table on every call. On hosts with many veth
|
|
||||||
// interfaces (containers, bridges) the resulting allocation churn was on
|
|
||||||
// the order of ~1 GB/day from this single ticker, which on small ARM
|
|
||||||
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
|
|
||||||
//
|
|
||||||
// The event-driven version below allocates only when the kernel actually
|
|
||||||
// publishes a link event for the tracked interface — typically zero
|
|
||||||
// allocations between events.
|
|
||||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
|
||||||
done := make(chan struct{})
|
|
||||||
defer close(done)
|
|
||||||
|
|
||||||
// Buffer the channel to absorb event bursts (e.g. when many veth
|
|
||||||
// pairs are created/destroyed at once by container runtimes).
|
|
||||||
linkChan := make(chan netlink.LinkUpdate, 32)
|
|
||||||
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
|
|
||||||
// Return shouldRestart=true so the engine recovers monitoring
|
|
||||||
// via triggerClientRestart instead of silently losing it for
|
|
||||||
// the rest of the process lifetime.
|
|
||||||
return true, fmt.Errorf("subscribe to link updates: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Race window: the interface could have been deleted (or recreated)
|
|
||||||
// between the initial getInterfaceIndex() in Start and LinkSubscribe
|
|
||||||
// completing its handshake with the kernel. Re-check explicitly so we
|
|
||||||
// do not block forever waiting for an event that already fired.
|
|
||||||
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
|
|
||||||
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
|
|
||||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
|
||||||
} else if currentIndex != expectedIndex {
|
|
||||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
|
|
||||||
ifaceName, expectedIndex, currentIndex)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
|
||||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
|
||||||
|
|
||||||
case update, ok := <-linkChan:
|
|
||||||
if !ok {
|
|
||||||
// The vishvananda/netlink subscription goroutine closes
|
|
||||||
// the channel on receive errors. Signal the engine to
|
|
||||||
// restart so monitoring is re-established instead of
|
|
||||||
// silently ending.
|
|
||||||
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
|
|
||||||
return true, fmt.Errorf("link subscription channel closed unexpectedly")
|
|
||||||
}
|
|
||||||
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
|
|
||||||
return true, err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// inspectLinkEvent classifies a single netlink link update against the
|
|
||||||
// tracked WireGuard interface. It returns (true, err) when the engine
|
|
||||||
// should restart monitoring; (false, nil) means the event is unrelated
|
|
||||||
// and the caller should keep waiting.
|
|
||||||
//
|
|
||||||
// The error component, when non-nil, describes the kernel-side reason
|
|
||||||
// (deletion or rename); the recreation case returns (true, nil) since
|
|
||||||
// no error condition is reported.
|
|
||||||
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
|
|
||||||
eventIndex := int(update.Index)
|
|
||||||
eventName := ""
|
|
||||||
if attrs := update.Attrs(); attrs != nil {
|
|
||||||
eventName = attrs.Name
|
|
||||||
}
|
|
||||||
|
|
||||||
switch update.Header.Type {
|
|
||||||
case syscall.RTM_DELLINK:
|
|
||||||
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
|
|
||||||
case syscall.RTM_NEWLINK:
|
|
||||||
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
|
|
||||||
// tracked interface index.
|
|
||||||
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
|
|
||||||
if eventIndex != expectedIndex {
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
|
||||||
return true, fmt.Errorf("interface %s deleted", ifaceName)
|
|
||||||
}
|
|
||||||
|
|
||||||
// inspectNewLink reports a restart when an RTM_NEWLINK either:
|
|
||||||
//
|
|
||||||
// 1. Introduces a link with our name at a different index (recreation
|
|
||||||
// after a delete), or
|
|
||||||
//
|
|
||||||
// 2. Reports a link still at our index but with a different name
|
|
||||||
// (in-place rename). The previous polling implementation caught
|
|
||||||
// this implicitly because net.InterfaceByName(ifaceName) would
|
|
||||||
// start failing; the event-driven version has to test it.
|
|
||||||
//
|
|
||||||
// Same name + same index is just a flag/state change on the existing
|
|
||||||
// interface and is ignored.
|
|
||||||
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
|
|
||||||
if eventName == ifaceName && eventIndex != expectedIndex {
|
|
||||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
|
||||||
ifaceName, expectedIndex, eventIndex)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
|
|
||||||
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
|
|
||||||
ifaceName, eventName, expectedIndex)
|
|
||||||
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
|
|
||||||
}
|
|
||||||
return false, nil
|
|
||||||
}
|
|
||||||
@@ -1,56 +0,0 @@
|
|||||||
//go:build !linux
|
|
||||||
|
|
||||||
package internal
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
// watchInterface polls net.InterfaceByName at a fixed interval to detect
|
|
||||||
// deletion or recreation of the WireGuard interface.
|
|
||||||
//
|
|
||||||
// This is the fallback used on non-Linux desktop and server platforms
|
|
||||||
// (darwin, windows, freebsd). It is also compiled on android and ios so
|
|
||||||
// the package builds on every supported GOOS, but it is never reached
|
|
||||||
// at runtime there because Start() in wg_iface_monitor.go exits early
|
|
||||||
// on mobile platforms.
|
|
||||||
//
|
|
||||||
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
|
|
||||||
// RTNLGRP_LINK netlink subscription instead, because on Linux
|
|
||||||
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
|
|
||||||
// dumps the entire kernel link table on every call and produces
|
|
||||||
// significant allocation churn (netbirdio/netbird#3678).
|
|
||||||
//
|
|
||||||
// Windows is also reported in #3678 as affected by RSS climb. A future
|
|
||||||
// follow-up could implement an event-driven watcher there using
|
|
||||||
// NotifyIpInterfaceChange from iphlpapi.
|
|
||||||
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
|
|
||||||
ticker := time.NewTicker(2 * time.Second)
|
|
||||||
defer ticker.Stop()
|
|
||||||
|
|
||||||
for {
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
log.Infof("Interface monitor: stopped for %s", ifaceName)
|
|
||||||
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
|
|
||||||
case <-ticker.C:
|
|
||||||
currentIndex, err := getInterfaceIndex(ifaceName)
|
|
||||||
if err != nil {
|
|
||||||
// Interface was deleted
|
|
||||||
log.Infof("Interface monitor: %s deleted", ifaceName)
|
|
||||||
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if interface index changed (interface was recreated)
|
|
||||||
if currentIndex != expectedIndex {
|
|
||||||
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
|
|
||||||
ifaceName, expectedIndex, currentIndex)
|
|
||||||
return true, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -13,25 +13,15 @@
|
|||||||
|
|
||||||
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
|
<MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>
|
||||||
|
|
||||||
<!-- Autostart: enabled by default, disable with AUTOSTART=0 on the msiexec command line -->
|
|
||||||
<Property Id="AUTOSTART" Value="1" />
|
|
||||||
|
|
||||||
<StandardDirectory Id="ProgramFiles64Folder">
|
<StandardDirectory Id="ProgramFiles64Folder">
|
||||||
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
<Directory Id="NetbirdInstallDir" Name="Netbird">
|
||||||
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
<Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
|
||||||
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
<Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
|
||||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
|
||||||
</Shortcut>
|
|
||||||
<Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
|
|
||||||
<ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
|
|
||||||
<ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
|
|
||||||
</Shortcut>
|
|
||||||
</File>
|
</File>
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
|
||||||
<File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
|
|
||||||
<?if $(var.ArchSuffix) = "amd64" ?>
|
<?if $(var.ArchSuffix) = "amd64" ?>
|
||||||
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
<File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
|
||||||
<?endif ?>
|
<?endif ?>
|
||||||
@@ -56,30 +46,8 @@
|
|||||||
</Directory>
|
</Directory>
|
||||||
</StandardDirectory>
|
</StandardDirectory>
|
||||||
|
|
||||||
<!-- Per-user component: HKCU keypath (auto GUID via "*"), separate from
|
|
||||||
the per-machine NetbirdFiles component to satisfy ICE57. -->
|
|
||||||
<StandardDirectory Id="ProgramMenuFolder">
|
|
||||||
<Component Id="NetbirdAumidRegistry" Guid="*">
|
|
||||||
<RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
|
|
||||||
<RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
|
|
||||||
</RegistryKey>
|
|
||||||
</Component>
|
|
||||||
</StandardDirectory>
|
|
||||||
|
|
||||||
<StandardDirectory Id="CommonAppDataFolder">
|
|
||||||
<Directory Id="NetbirdAutoStartDir" Name="Netbird">
|
|
||||||
<Component Id="NetbirdAutoStart" Guid="b199eaca-b0dd-4032-af19-679cfad48eb3" Bitness="always64" Condition='AUTOSTART = "1"'>
|
|
||||||
<RegistryValue Root="HKLM" Key="Software\Microsoft\Windows\CurrentVersion\Run"
|
|
||||||
Name="Netbird" Value=""[NetbirdInstallDir]netbird-ui.exe""
|
|
||||||
Type="string" KeyPath="yes" />
|
|
||||||
</Component>
|
|
||||||
</Directory>
|
|
||||||
</StandardDirectory>
|
|
||||||
|
|
||||||
<ComponentGroup Id="NetbirdFilesComponent">
|
<ComponentGroup Id="NetbirdFilesComponent">
|
||||||
<ComponentRef Id="NetbirdFiles" />
|
<ComponentRef Id="NetbirdFiles" />
|
||||||
<ComponentRef Id="NetbirdAumidRegistry" />
|
|
||||||
<ComponentRef Id="NetbirdAutoStart" />
|
|
||||||
</ComponentGroup>
|
</ComponentGroup>
|
||||||
|
|
||||||
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
<util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -64,17 +64,6 @@ service DaemonService {
|
|||||||
|
|
||||||
rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}
|
rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}
|
||||||
|
|
||||||
// StartCapture begins streaming packet capture on the WireGuard interface.
|
|
||||||
// Requires --enable-capture set at service install/reconfigure time.
|
|
||||||
rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {}
|
|
||||||
|
|
||||||
// StartBundleCapture begins capturing packets to a server-side temp file
|
|
||||||
// for inclusion in the next debug bundle. Auto-stops after the given timeout.
|
|
||||||
rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {}
|
|
||||||
|
|
||||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
|
||||||
rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {}
|
|
||||||
|
|
||||||
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
|
rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}
|
||||||
|
|
||||||
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
|
rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}
|
||||||
@@ -115,6 +104,8 @@ service DaemonService {
|
|||||||
// StopCPUProfile stops CPU profiling in the daemon
|
// StopCPUProfile stops CPU profiling in the daemon
|
||||||
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}
|
||||||
|
|
||||||
|
rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
|
||||||
|
|
||||||
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}
|
||||||
|
|
||||||
// ExposeService exposes a local port via the NetBird reverse proxy
|
// ExposeService exposes a local port via the NetBird reverse proxy
|
||||||
@@ -123,6 +114,20 @@ service DaemonService {
|
|||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
message OSLifecycleRequest {
|
||||||
|
// avoid collision with loglevel enum
|
||||||
|
enum CycleType {
|
||||||
|
UNKNOWN = 0;
|
||||||
|
SLEEP = 1;
|
||||||
|
WAKEUP = 2;
|
||||||
|
}
|
||||||
|
|
||||||
|
CycleType type = 1;
|
||||||
|
}
|
||||||
|
|
||||||
|
message OSLifecycleResponse {}
|
||||||
|
|
||||||
|
|
||||||
message LoginRequest {
|
message LoginRequest {
|
||||||
// setupKey netbird setup key.
|
// setupKey netbird setup key.
|
||||||
string setupKey = 1;
|
string setupKey = 1;
|
||||||
@@ -843,26 +848,3 @@ message ExposeServiceReady {
|
|||||||
string domain = 3;
|
string domain = 3;
|
||||||
bool port_auto_assigned = 4;
|
bool port_auto_assigned = 4;
|
||||||
}
|
}
|
||||||
|
|
||||||
message StartCaptureRequest {
|
|
||||||
bool text_output = 1;
|
|
||||||
uint32 snap_len = 2;
|
|
||||||
google.protobuf.Duration duration = 3;
|
|
||||||
string filter_expr = 4;
|
|
||||||
bool verbose = 5;
|
|
||||||
bool ascii = 6;
|
|
||||||
}
|
|
||||||
|
|
||||||
message CapturePacket {
|
|
||||||
bytes data = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message StartBundleCaptureRequest {
|
|
||||||
// timeout auto-stops the capture after this duration.
|
|
||||||
// Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum.
|
|
||||||
google.protobuf.Duration timeout = 1;
|
|
||||||
}
|
|
||||||
|
|
||||||
message StartBundleCaptureResponse {}
|
|
||||||
message StopBundleCaptureRequest {}
|
|
||||||
message StopBundleCaptureResponse {}
|
|
||||||
|
|||||||
File diff suppressed because it is too large
Load Diff
@@ -1,365 +0,0 @@
|
|||||||
package server
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"io"
|
|
||||||
"os"
|
|
||||||
"sync"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"google.golang.org/grpc/codes"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
"github.com/netbirdio/netbird/util/capture"
|
|
||||||
)
|
|
||||||
|
|
||||||
const maxBundleCaptureDuration = 10 * time.Minute
|
|
||||||
|
|
||||||
// bundleCapture holds the state of an in-progress capture destined for the
|
|
||||||
// debug bundle. The lifecycle is:
|
|
||||||
//
|
|
||||||
// StartBundleCapture → capture running, writing to temp file
|
|
||||||
// StopBundleCapture → capture stopped, temp file available
|
|
||||||
// DebugBundle → temp file included in zip, then cleaned up
|
|
||||||
type bundleCapture struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
sess *capture.Session
|
|
||||||
file *os.File
|
|
||||||
engine *internal.Engine
|
|
||||||
cancel context.CancelFunc
|
|
||||||
stopped bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// stop halts the capture session and closes the pcap writer. Idempotent.
|
|
||||||
func (bc *bundleCapture) stop() {
|
|
||||||
bc.mu.Lock()
|
|
||||||
defer bc.mu.Unlock()
|
|
||||||
|
|
||||||
if bc.stopped {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bc.stopped = true
|
|
||||||
|
|
||||||
if bc.cancel != nil {
|
|
||||||
bc.cancel()
|
|
||||||
}
|
|
||||||
if bc.sess != nil {
|
|
||||||
bc.sess.Stop()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// path returns the temp file path, or "" if no file exists.
|
|
||||||
func (bc *bundleCapture) path() string {
|
|
||||||
if bc.file == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
return bc.file.Name()
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanup removes the temp file.
|
|
||||||
func (bc *bundleCapture) cleanup() {
|
|
||||||
if bc.file == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
name := bc.file.Name()
|
|
||||||
if err := bc.file.Close(); err != nil {
|
|
||||||
log.Debugf("close bundle capture file: %v", err)
|
|
||||||
}
|
|
||||||
if err := os.Remove(name); err != nil && !os.IsNotExist(err) {
|
|
||||||
log.Debugf("remove bundle capture file: %v", err)
|
|
||||||
}
|
|
||||||
bc.file = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartCapture streams a pcap or text packet capture over gRPC.
|
|
||||||
// Gated by the --enable-capture service flag.
|
|
||||||
func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error {
|
|
||||||
if !s.captureEnabled {
|
|
||||||
return status.Error(codes.PermissionDenied,
|
|
||||||
"packet capture is disabled; reinstall or reconfigure the service with --enable-capture")
|
|
||||||
}
|
|
||||||
|
|
||||||
if d := req.GetDuration(); d != nil && d.AsDuration() < 0 {
|
|
||||||
return status.Error(codes.InvalidArgument, "duration must not be negative")
|
|
||||||
}
|
|
||||||
|
|
||||||
matcher, err := parseCaptureFilter(req)
|
|
||||||
if err != nil {
|
|
||||||
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
pr, pw := io.Pipe()
|
|
||||||
|
|
||||||
opts := capture.Options{
|
|
||||||
Matcher: matcher,
|
|
||||||
SnapLen: req.GetSnapLen(),
|
|
||||||
Verbose: req.GetVerbose(),
|
|
||||||
ASCII: req.GetAscii(),
|
|
||||||
}
|
|
||||||
if req.GetTextOutput() {
|
|
||||||
opts.TextOutput = pw
|
|
||||||
} else {
|
|
||||||
opts.Output = pw
|
|
||||||
}
|
|
||||||
|
|
||||||
sess, err := capture.NewSession(opts)
|
|
||||||
if err != nil {
|
|
||||||
pw.Close()
|
|
||||||
return status.Errorf(codes.Internal, "create capture session: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
engine, err := s.claimCapture(sess)
|
|
||||||
if err != nil {
|
|
||||||
sess.Stop()
|
|
||||||
pw.Close()
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := engine.SetCapture(sess); err != nil {
|
|
||||||
s.releaseCapture(sess)
|
|
||||||
sess.Stop()
|
|
||||||
pw.Close()
|
|
||||||
return status.Errorf(codes.Internal, "set capture: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Send an empty initial message to signal that the capture was accepted.
|
|
||||||
// The client waits for this before printing the banner, so it must arrive
|
|
||||||
// before any packet data.
|
|
||||||
if err := stream.Send(&proto.CapturePacket{}); err != nil {
|
|
||||||
s.clearCaptureIfOwner(sess, engine)
|
|
||||||
sess.Stop()
|
|
||||||
pw.Close()
|
|
||||||
return status.Errorf(codes.Internal, "send initial message: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ctx := stream.Context()
|
|
||||||
if d := req.GetDuration(); d != nil {
|
|
||||||
if dur := d.AsDuration(); dur > 0 {
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
ctx, cancel = context.WithTimeout(ctx, dur)
|
|
||||||
defer cancel()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
s.clearCaptureIfOwner(sess, engine)
|
|
||||||
sess.Stop()
|
|
||||||
pw.Close()
|
|
||||||
}()
|
|
||||||
defer pr.Close()
|
|
||||||
|
|
||||||
log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr())
|
|
||||||
defer func() {
|
|
||||||
stats := sess.Stats()
|
|
||||||
log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped",
|
|
||||||
stats.Packets, stats.Bytes, stats.Dropped)
|
|
||||||
}()
|
|
||||||
|
|
||||||
return streamToGRPC(pr, stream)
|
|
||||||
}
|
|
||||||
|
|
||||||
func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error {
|
|
||||||
buf := make([]byte, 32*1024)
|
|
||||||
for {
|
|
||||||
n, readErr := r.Read(buf)
|
|
||||||
if n > 0 {
|
|
||||||
if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil {
|
|
||||||
log.Debugf("capture stream send: %v", err)
|
|
||||||
return nil //nolint:nilerr // client disconnected
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if readErr != nil {
|
|
||||||
return nil //nolint:nilerr // pipe closed, capture stopped normally
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartBundleCapture begins capturing packets to a server-side temp file for
|
|
||||||
// inclusion in the next debug bundle. Not gated by --enable-capture since the
|
|
||||||
// output stays on the server (same trust level as CPU profiling).
|
|
||||||
//
|
|
||||||
// A timeout auto-stops the capture as a safety net if StopBundleCapture is
|
|
||||||
// never called (e.g. CLI crash).
|
|
||||||
func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
s.stopBundleCaptureLocked()
|
|
||||||
s.cleanupBundleCapture()
|
|
||||||
|
|
||||||
if s.activeCapture != nil {
|
|
||||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
|
||||||
}
|
|
||||||
|
|
||||||
engine, err := s.getCaptureEngineLocked()
|
|
||||||
if err != nil {
|
|
||||||
// Not fatal: kernel mode or not connected. Log and return success
|
|
||||||
// so the debug bundle still generates without capture data.
|
|
||||||
log.Warnf("packet capture unavailable, skipping: %v", err)
|
|
||||||
return &proto.StartBundleCaptureResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
timeout := req.GetTimeout().AsDuration()
|
|
||||||
if timeout <= 0 || timeout > maxBundleCaptureDuration {
|
|
||||||
timeout = maxBundleCaptureDuration
|
|
||||||
}
|
|
||||||
|
|
||||||
f, err := os.CreateTemp("", "netbird.capture.*.pcap")
|
|
||||||
if err != nil {
|
|
||||||
return nil, status.Errorf(codes.Internal, "create temp file: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sess, err := capture.NewSession(capture.Options{Output: f})
|
|
||||||
if err != nil {
|
|
||||||
f.Close()
|
|
||||||
os.Remove(f.Name())
|
|
||||||
return nil, status.Errorf(codes.Internal, "create capture session: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := engine.SetCapture(sess); err != nil {
|
|
||||||
sess.Stop()
|
|
||||||
f.Close()
|
|
||||||
os.Remove(f.Name())
|
|
||||||
log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err)
|
|
||||||
return &proto.StartBundleCaptureResponse{}, nil
|
|
||||||
}
|
|
||||||
s.activeCapture = sess
|
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), timeout)
|
|
||||||
bc := &bundleCapture{
|
|
||||||
sess: sess,
|
|
||||||
file: f,
|
|
||||||
engine: engine,
|
|
||||||
cancel: cancel,
|
|
||||||
}
|
|
||||||
|
|
||||||
s.bundleCapture = bc
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-ctx.Done()
|
|
||||||
s.mutex.Lock()
|
|
||||||
if s.bundleCapture == bc {
|
|
||||||
s.stopBundleCaptureLocked()
|
|
||||||
} else {
|
|
||||||
bc.stop()
|
|
||||||
}
|
|
||||||
s.mutex.Unlock()
|
|
||||||
log.Infof("bundle capture auto-stopped after timeout")
|
|
||||||
}()
|
|
||||||
log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name())
|
|
||||||
|
|
||||||
return &proto.StartBundleCaptureResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StopBundleCapture stops the running bundle capture. Idempotent.
|
|
||||||
func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
s.stopBundleCaptureLocked()
|
|
||||||
return &proto.StopBundleCaptureResponse{}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex.
|
|
||||||
func (s *Server) stopBundleCaptureLocked() {
|
|
||||||
if s.bundleCapture == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
bc := s.bundleCapture
|
|
||||||
if bc.engine != nil && s.activeCapture == bc.sess {
|
|
||||||
if err := bc.engine.SetCapture(nil); err != nil {
|
|
||||||
log.Debugf("clear bundle capture: %v", err)
|
|
||||||
}
|
|
||||||
s.activeCapture = nil
|
|
||||||
}
|
|
||||||
bc.stop()
|
|
||||||
|
|
||||||
stats := bc.sess.Stats()
|
|
||||||
log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped",
|
|
||||||
stats.Packets, stats.Bytes, stats.Dropped)
|
|
||||||
}
|
|
||||||
|
|
||||||
// bundleCapturePath returns the temp file path if a capture has been taken,
|
|
||||||
// stops any running capture, and returns "". Called from DebugBundle.
|
|
||||||
// Must hold s.mutex.
|
|
||||||
func (s *Server) bundleCapturePath() string {
|
|
||||||
if s.bundleCapture == nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
|
|
||||||
s.bundleCapture.stop()
|
|
||||||
return s.bundleCapture.path()
|
|
||||||
}
|
|
||||||
|
|
||||||
// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex.
|
|
||||||
func (s *Server) cleanupBundleCapture() {
|
|
||||||
if s.bundleCapture == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
s.bundleCapture.cleanup()
|
|
||||||
s.bundleCapture = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// claimCapture reserves the engine's capture slot for sess. Returns
|
|
||||||
// FailedPrecondition if another capture is already active.
|
|
||||||
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
|
|
||||||
if s.activeCapture != nil {
|
|
||||||
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
|
|
||||||
}
|
|
||||||
engine, err := s.getCaptureEngineLocked()
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
s.activeCapture = sess
|
|
||||||
return engine, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// releaseCapture clears the active-capture owner if it still matches sess.
|
|
||||||
func (s *Server) releaseCapture(sess *capture.Session) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
if s.activeCapture == sess {
|
|
||||||
s.activeCapture = nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// clearCaptureIfOwner clears engine's capture slot only if sess still owns it.
|
|
||||||
func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) {
|
|
||||||
s.mutex.Lock()
|
|
||||||
defer s.mutex.Unlock()
|
|
||||||
if s.activeCapture != sess {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if err := engine.SetCapture(nil); err != nil {
|
|
||||||
log.Debugf("clear capture: %v", err)
|
|
||||||
}
|
|
||||||
s.activeCapture = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
|
|
||||||
if s.connectClient == nil {
|
|
||||||
return nil, status.Error(codes.FailedPrecondition, "client not connected")
|
|
||||||
}
|
|
||||||
engine := s.connectClient.Engine()
|
|
||||||
if engine == nil {
|
|
||||||
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")
|
|
||||||
}
|
|
||||||
return engine, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseCaptureFilter returns a Matcher from the request.
|
|
||||||
// Returns nil (match all) when no filter expression is set.
|
|
||||||
func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) {
|
|
||||||
expr := req.GetFilterExpr()
|
|
||||||
if expr == "" {
|
|
||||||
return nil, nil //nolint:nilnil // nil Matcher means "match all"
|
|
||||||
}
|
|
||||||
return capture.ParseFilter(expr)
|
|
||||||
}
|
|
||||||
@@ -43,9 +43,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
}()
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
capturePath := s.bundleCapturePath()
|
// Prepare refresh callback for health probes
|
||||||
defer s.cleanupBundleCapture()
|
|
||||||
|
|
||||||
var refreshStatus func()
|
var refreshStatus func()
|
||||||
if s.connectClient != nil {
|
if s.connectClient != nil {
|
||||||
engine := s.connectClient.Engine()
|
engine := s.connectClient.Engine()
|
||||||
@@ -64,7 +62,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
|
|||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: s.logFile,
|
LogPath: s.logFile,
|
||||||
CPUProfile: cpuProfileData,
|
CPUProfile: cpuProfileData,
|
||||||
CapturePath: capturePath,
|
|
||||||
RefreshStatus: refreshStatus,
|
RefreshStatus: refreshStatus,
|
||||||
ClientMetrics: clientMetrics,
|
ClientMetrics: clientMetrics,
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -33,7 +33,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/updater"
|
"github.com/netbirdio/netbird/client/internal/updater"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/util/capture"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -90,11 +89,7 @@ type Server struct {
|
|||||||
profileManager *profilemanager.ServiceManager
|
profileManager *profilemanager.ServiceManager
|
||||||
profilesDisabled bool
|
profilesDisabled bool
|
||||||
updateSettingsDisabled bool
|
updateSettingsDisabled bool
|
||||||
captureEnabled bool
|
networksDisabled bool
|
||||||
bundleCapture *bundleCapture
|
|
||||||
// activeCapture is the session currently installed on the engine; guarded by s.mutex.
|
|
||||||
activeCapture *capture.Session
|
|
||||||
networksDisabled bool
|
|
||||||
|
|
||||||
sleepHandler *sleephandler.SleepHandler
|
sleepHandler *sleephandler.SleepHandler
|
||||||
|
|
||||||
@@ -111,7 +106,7 @@ type oauthAuthFlow struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// New server instance constructor.
|
// New server instance constructor.
|
||||||
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server {
|
func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
|
||||||
s := &Server{
|
s := &Server{
|
||||||
rootCtx: ctx,
|
rootCtx: ctx,
|
||||||
logFile: logFile,
|
logFile: logFile,
|
||||||
@@ -120,13 +115,11 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
|
|||||||
profileManager: profilemanager.NewServiceManager(configFile),
|
profileManager: profilemanager.NewServiceManager(configFile),
|
||||||
profilesDisabled: profilesDisabled,
|
profilesDisabled: profilesDisabled,
|
||||||
updateSettingsDisabled: updateSettingsDisabled,
|
updateSettingsDisabled: updateSettingsDisabled,
|
||||||
captureEnabled: captureEnabled,
|
|
||||||
networksDisabled: networksDisabled,
|
networksDisabled: networksDisabled,
|
||||||
jwtCache: newJWTCache(),
|
jwtCache: newJWTCache(),
|
||||||
}
|
}
|
||||||
agent := &serverAgent{s}
|
agent := &serverAgent{s}
|
||||||
s.sleepHandler = sleephandler.New(agent)
|
s.sleepHandler = sleephandler.New(agent)
|
||||||
s.startSleepDetector()
|
|
||||||
|
|
||||||
return s
|
return s
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -104,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
|
|||||||
t.Fatalf("failed to set active profile state: %v", err)
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := New(ctx, "debug", "", false, false, false, false)
|
s := New(ctx, "debug", "", false, false, false)
|
||||||
|
|
||||||
s.config = config
|
s.config = config
|
||||||
|
|
||||||
@@ -165,7 +165,7 @@ func TestServer_Up(t *testing.T) {
|
|||||||
t.Fatalf("failed to set active profile state: %v", err)
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := New(ctx, "console", "", false, false, false, false)
|
s := New(ctx, "console", "", false, false, false)
|
||||||
err = s.Start()
|
err = s.Start()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -235,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
|
|||||||
t.Fatalf("failed to set active profile state: %v", err)
|
t.Fatalf("failed to set active profile state: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s := New(ctx, "console", "", false, false, false, false)
|
s := New(ctx, "console", "", false, false, false)
|
||||||
|
|
||||||
err = s.Start()
|
err = s.Start()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
@@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
|
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctx := context.Background()
|
ctx := context.Background()
|
||||||
s := New(ctx, "console", "", false, false, false, false)
|
s := New(ctx, "console", "", false, false, false)
|
||||||
|
|
||||||
rosenpassEnabled := true
|
rosenpassEnabled := true
|
||||||
rosenpassPermissive := true
|
rosenpassPermissive := true
|
||||||
|
|||||||
@@ -2,18 +2,13 @@ package server
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"os"
|
|
||||||
"strconv"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/sleep"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const envDisableSleepDetector = "NB_DISABLE_SLEEP_DETECTOR"
|
|
||||||
|
|
||||||
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
// serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
|
||||||
type serverAgent struct {
|
type serverAgent struct {
|
||||||
s *Server
|
s *Server
|
||||||
@@ -33,61 +28,19 @@ func (a *serverAgent) Status() (internal.StatusType, error) {
|
|||||||
return internal.CtxGetState(a.s.rootCtx).Status()
|
return internal.CtxGetState(a.s.rootCtx).Status()
|
||||||
}
|
}
|
||||||
|
|
||||||
// startSleepDetector starts the OS sleep/wake detector and forwards events to
|
// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
|
||||||
// the sleep handler. On platforms without a supported detector the attempt
|
func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
|
||||||
// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips
|
switch req.GetType() {
|
||||||
// registration entirely.
|
case proto.OSLifecycleRequest_WAKEUP:
|
||||||
func (s *Server) startSleepDetector() {
|
if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
|
||||||
if sleepDetectorDisabled() {
|
return &proto.OSLifecycleResponse{}, err
|
||||||
log.Info("sleep detection disabled via " + envDisableSleepDetector)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
svc, err := sleep.New()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to initialize sleep detection: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
err = svc.Register(func(event sleep.EventType) {
|
|
||||||
switch event {
|
|
||||||
case sleep.EventTypeSleep:
|
|
||||||
log.Info("handling sleep event")
|
|
||||||
if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil {
|
|
||||||
log.Errorf("failed to handle sleep event: %v", err)
|
|
||||||
}
|
|
||||||
case sleep.EventTypeWakeUp:
|
|
||||||
log.Info("handling wakeup event")
|
|
||||||
if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil {
|
|
||||||
log.Errorf("failed to handle wakeup event: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
})
|
case proto.OSLifecycleRequest_SLEEP:
|
||||||
if err != nil {
|
if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
|
||||||
log.Errorf("failed to register sleep detector: %v", err)
|
return &proto.OSLifecycleResponse{}, err
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Info("sleep detection service initialized")
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
<-s.rootCtx.Done()
|
|
||||||
log.Info("stopping sleep event listener")
|
|
||||||
if err := svc.Deregister(); err != nil {
|
|
||||||
log.Errorf("failed to deregister sleep detector: %v", err)
|
|
||||||
}
|
}
|
||||||
}()
|
default:
|
||||||
}
|
log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
|
||||||
|
|
||||||
func sleepDetectorDisabled() bool {
|
|
||||||
val := os.Getenv(envDisableSleepDetector)
|
|
||||||
if val == "" {
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
disabled, err := strconv.ParseBool(val)
|
return &proto.OSLifecycleResponse{}, nil
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err)
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
return disabled
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -224,20 +224,15 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
|
|||||||
|
|
||||||
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
func (m *Manager) writeSSHConfig(sshConfig string) error {
|
||||||
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
|
||||||
sshConfigPathTmp := sshConfigPath + ".tmp"
|
|
||||||
|
|
||||||
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
|
||||||
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
|
if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
|
||||||
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
|
|
||||||
return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,10 +38,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/sleep"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||||
"github.com/netbirdio/netbird/client/ui/event"
|
"github.com/netbirdio/netbird/client/ui/event"
|
||||||
"github.com/netbirdio/netbird/client/ui/notifier"
|
|
||||||
"github.com/netbirdio/netbird/client/ui/process"
|
"github.com/netbirdio/netbird/client/ui/process"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
@@ -260,7 +260,6 @@ type serviceClient struct {
|
|||||||
|
|
||||||
// application with main windows.
|
// application with main windows.
|
||||||
app fyne.App
|
app fyne.App
|
||||||
notifier notifier.Notifier
|
|
||||||
wSettings fyne.Window
|
wSettings fyne.Window
|
||||||
showAdvancedSettings bool
|
showAdvancedSettings bool
|
||||||
sendNotification bool
|
sendNotification bool
|
||||||
@@ -365,7 +364,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
|
|||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
addr: args.addr,
|
addr: args.addr,
|
||||||
app: args.app,
|
app: args.app,
|
||||||
notifier: notifier.New(args.app),
|
|
||||||
logFile: args.logFile,
|
logFile: args.logFile,
|
||||||
sendNotification: false,
|
sendNotification: false,
|
||||||
|
|
||||||
@@ -894,7 +892,7 @@ func (s *serviceClient) updateStatus() error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("get service status: %v", err)
|
log.Errorf("get service status: %v", err)
|
||||||
if s.connected {
|
if s.connected {
|
||||||
s.notifier.Send("Error", "Connection to service lost")
|
s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
|
||||||
}
|
}
|
||||||
s.setDisconnectedStatus()
|
s.setDisconnectedStatus()
|
||||||
return err
|
return err
|
||||||
@@ -1111,7 +1109,7 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
s.eventManager = event.NewManager(s.notifier, s.addr)
|
s.eventManager = event.NewManager(s.app, s.addr)
|
||||||
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
|
||||||
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
s.eventManager.AddHandler(func(event *proto.SystemEvent) {
|
||||||
if event.Category == proto.SystemEvent_SYSTEM {
|
if event.Category == proto.SystemEvent_SYSTEM {
|
||||||
@@ -1148,6 +1146,9 @@ func (s *serviceClient) onTrayReady() {
|
|||||||
|
|
||||||
go s.eventManager.Start(s.ctx)
|
go s.eventManager.Start(s.ctx)
|
||||||
go s.eventHandler.listen(s.ctx)
|
go s.eventHandler.listen(s.ctx)
|
||||||
|
|
||||||
|
// Start sleep detection listener
|
||||||
|
go s.startSleepListener()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {
|
||||||
@@ -1208,6 +1209,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
|
|||||||
return s.conn, nil
|
return s.conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// startSleepListener initializes the sleep detection service and listens for sleep events
|
||||||
|
func (s *serviceClient) startSleepListener() {
|
||||||
|
sleepService, err := sleep.New()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("%v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := sleepService.Register(s.handleSleepEvents); err != nil {
|
||||||
|
log.Errorf("failed to start sleep detection: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("sleep detection service initialized")
|
||||||
|
|
||||||
|
// Cleanup on context cancellation
|
||||||
|
go func() {
|
||||||
|
<-s.ctx.Done()
|
||||||
|
log.Info("stopping sleep event listener")
|
||||||
|
if err := sleepService.Deregister(); err != nil {
|
||||||
|
log.Errorf("failed to deregister sleep detection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleSleepEvents sends a sleep notification to the daemon via gRPC
|
||||||
|
func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
|
||||||
|
conn, err := s.getSrvClient(0)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get daemon client for sleep notification: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &proto.OSLifecycleRequest{}
|
||||||
|
|
||||||
|
switch event {
|
||||||
|
case sleep.EventTypeWakeUp:
|
||||||
|
log.Infof("handle wakeup event: %v", event)
|
||||||
|
req.Type = proto.OSLifecycleRequest_WAKEUP
|
||||||
|
case sleep.EventTypeSleep:
|
||||||
|
log.Infof("handle sleep event: %v", event)
|
||||||
|
req.Type = proto.OSLifecycleRequest_SLEEP
|
||||||
|
default:
|
||||||
|
log.Infof("unknown event: %v", event)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = conn.NotifyOSLifecycle(s.ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("successfully notified daemon about os lifecycle")
|
||||||
|
}
|
||||||
|
|
||||||
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
// setSettingsEnabled enables or disables the settings menu based on the provided state
|
||||||
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
func (s *serviceClient) setSettingsEnabled(enabled bool) {
|
||||||
if s.mSettings != nil {
|
if s.mSettings != nil {
|
||||||
@@ -1491,7 +1548,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) {
|
|||||||
|
|
||||||
if enforced && s.lastNotifiedVersion != newVersion {
|
if enforced && s.lastNotifiedVersion != newVersion {
|
||||||
s.lastNotifiedVersion = newVersion
|
s.lastNotifiedVersion = newVersion
|
||||||
s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install")
|
s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -16,7 +16,6 @@ import (
|
|||||||
"fyne.io/fyne/v2/widget"
|
"fyne.io/fyne/v2/widget"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/skratchdot/open-golang/open"
|
"github.com/skratchdot/open-golang/open"
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
@@ -39,7 +38,6 @@ type debugCollectionParams struct {
|
|||||||
upload bool
|
upload bool
|
||||||
uploadURL string
|
uploadURL string
|
||||||
enablePersistence bool
|
enablePersistence bool
|
||||||
capture bool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// UI components for progress tracking
|
// UI components for progress tracking
|
||||||
@@ -53,58 +51,25 @@ type progressUI struct {
|
|||||||
func (s *serviceClient) showDebugUI() {
|
func (s *serviceClient) showDebugUI() {
|
||||||
w := s.app.NewWindow("NetBird Debug")
|
w := s.app.NewWindow("NetBird Debug")
|
||||||
w.SetOnClosed(s.cancel)
|
w.SetOnClosed(s.cancel)
|
||||||
|
|
||||||
w.Resize(fyne.NewSize(600, 500))
|
w.Resize(fyne.NewSize(600, 500))
|
||||||
w.SetFixedSize(true)
|
w.SetFixedSize(true)
|
||||||
|
|
||||||
anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)
|
anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)
|
||||||
systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
|
systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
|
||||||
systemInfoCheck.SetChecked(true)
|
systemInfoCheck.SetChecked(true)
|
||||||
captureCheck := widget.NewCheck("Include packet capture", nil)
|
|
||||||
uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
|
uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
|
||||||
uploadCheck.SetChecked(true)
|
uploadCheck.SetChecked(true)
|
||||||
|
|
||||||
uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck)
|
uploadURLLabel := widget.NewLabel("Debug upload URL:")
|
||||||
|
|
||||||
debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection()
|
|
||||||
|
|
||||||
statusLabel := widget.NewLabel("")
|
|
||||||
statusLabel.Hide()
|
|
||||||
progressBar := widget.NewProgressBar()
|
|
||||||
progressBar.Hide()
|
|
||||||
createButton := widget.NewButton("Create Debug Bundle", nil)
|
|
||||||
|
|
||||||
uiControls := []fyne.Disableable{
|
|
||||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
|
||||||
uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton,
|
|
||||||
}
|
|
||||||
|
|
||||||
createButton.OnTapped = s.getCreateHandler(
|
|
||||||
statusLabel, progressBar, uploadCheck, uploadURL,
|
|
||||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
|
||||||
runForDurationCheck, durationInput, uiControls, w,
|
|
||||||
)
|
|
||||||
|
|
||||||
content := container.NewVBox(
|
|
||||||
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
|
|
||||||
widget.NewLabel(""),
|
|
||||||
anonymizeCheck, systemInfoCheck, captureCheck,
|
|
||||||
uploadCheck, uploadURLContainer,
|
|
||||||
widget.NewLabel(""),
|
|
||||||
debugModeContainer, noteLabel,
|
|
||||||
widget.NewLabel(""),
|
|
||||||
statusLabel, progressBar, createButton,
|
|
||||||
)
|
|
||||||
|
|
||||||
w.SetContent(container.NewPadded(content))
|
|
||||||
w.Show()
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) {
|
|
||||||
uploadURL := widget.NewEntry()
|
uploadURL := widget.NewEntry()
|
||||||
uploadURL.SetText(uptypes.DefaultBundleURL)
|
uploadURL.SetText(uptypes.DefaultBundleURL)
|
||||||
uploadURL.SetPlaceHolder("Enter upload URL")
|
uploadURL.SetPlaceHolder("Enter upload URL")
|
||||||
|
|
||||||
uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL)
|
uploadURLContainer := container.NewVBox(
|
||||||
|
uploadURLLabel,
|
||||||
|
uploadURL,
|
||||||
|
)
|
||||||
|
|
||||||
uploadCheck.OnChanged = func(checked bool) {
|
uploadCheck.OnChanged = func(checked bool) {
|
||||||
if checked {
|
if checked {
|
||||||
@@ -113,14 +78,13 @@ func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Con
|
|||||||
uploadURLContainer.Hide()
|
uploadURLContainer.Hide()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return uploadURLContainer, uploadURL
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) {
|
debugModeContainer := container.NewHBox()
|
||||||
runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
|
runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
|
||||||
runForDurationCheck.SetChecked(true)
|
runForDurationCheck.SetChecked(true)
|
||||||
|
|
||||||
forLabel := widget.NewLabel("for")
|
forLabel := widget.NewLabel("for")
|
||||||
|
|
||||||
durationInput := widget.NewEntry()
|
durationInput := widget.NewEntry()
|
||||||
durationInput.SetText("1")
|
durationInput.SetText("1")
|
||||||
minutesLabel := widget.NewLabel("minute")
|
minutesLabel := widget.NewLabel("minute")
|
||||||
@@ -144,8 +108,63 @@ func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check,
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel)
|
debugModeContainer.Add(runForDurationCheck)
|
||||||
return modeContainer, runForDurationCheck, durationInput, noteLabel
|
debugModeContainer.Add(forLabel)
|
||||||
|
debugModeContainer.Add(durationInput)
|
||||||
|
debugModeContainer.Add(minutesLabel)
|
||||||
|
|
||||||
|
statusLabel := widget.NewLabel("")
|
||||||
|
statusLabel.Hide()
|
||||||
|
|
||||||
|
progressBar := widget.NewProgressBar()
|
||||||
|
progressBar.Hide()
|
||||||
|
|
||||||
|
createButton := widget.NewButton("Create Debug Bundle", nil)
|
||||||
|
|
||||||
|
// UI controls that should be disabled during debug collection
|
||||||
|
uiControls := []fyne.Disableable{
|
||||||
|
anonymizeCheck,
|
||||||
|
systemInfoCheck,
|
||||||
|
uploadCheck,
|
||||||
|
uploadURL,
|
||||||
|
runForDurationCheck,
|
||||||
|
durationInput,
|
||||||
|
createButton,
|
||||||
|
}
|
||||||
|
|
||||||
|
createButton.OnTapped = s.getCreateHandler(
|
||||||
|
statusLabel,
|
||||||
|
progressBar,
|
||||||
|
uploadCheck,
|
||||||
|
uploadURL,
|
||||||
|
anonymizeCheck,
|
||||||
|
systemInfoCheck,
|
||||||
|
runForDurationCheck,
|
||||||
|
durationInput,
|
||||||
|
uiControls,
|
||||||
|
w,
|
||||||
|
)
|
||||||
|
|
||||||
|
content := container.NewVBox(
|
||||||
|
widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
|
||||||
|
widget.NewLabel(""),
|
||||||
|
anonymizeCheck,
|
||||||
|
systemInfoCheck,
|
||||||
|
uploadCheck,
|
||||||
|
uploadURLContainer,
|
||||||
|
widget.NewLabel(""),
|
||||||
|
debugModeContainer,
|
||||||
|
noteLabel,
|
||||||
|
widget.NewLabel(""),
|
||||||
|
statusLabel,
|
||||||
|
progressBar,
|
||||||
|
createButton,
|
||||||
|
)
|
||||||
|
|
||||||
|
paddedContent := container.NewPadded(content)
|
||||||
|
w.SetContent(paddedContent)
|
||||||
|
|
||||||
|
w.Show()
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateMinute(s string, minutesLabel *widget.Label) error {
|
func validateMinute(s string, minutesLabel *widget.Label) error {
|
||||||
@@ -181,7 +200,6 @@ func (s *serviceClient) getCreateHandler(
|
|||||||
uploadURL *widget.Entry,
|
uploadURL *widget.Entry,
|
||||||
anonymizeCheck *widget.Check,
|
anonymizeCheck *widget.Check,
|
||||||
systemInfoCheck *widget.Check,
|
systemInfoCheck *widget.Check,
|
||||||
captureCheck *widget.Check,
|
|
||||||
runForDurationCheck *widget.Check,
|
runForDurationCheck *widget.Check,
|
||||||
duration *widget.Entry,
|
duration *widget.Entry,
|
||||||
uiControls []fyne.Disableable,
|
uiControls []fyne.Disableable,
|
||||||
@@ -204,7 +222,6 @@ func (s *serviceClient) getCreateHandler(
|
|||||||
params := &debugCollectionParams{
|
params := &debugCollectionParams{
|
||||||
anonymize: anonymizeCheck.Checked,
|
anonymize: anonymizeCheck.Checked,
|
||||||
systemInfo: systemInfoCheck.Checked,
|
systemInfo: systemInfoCheck.Checked,
|
||||||
capture: captureCheck.Checked,
|
|
||||||
upload: uploadCheck.Checked,
|
upload: uploadCheck.Checked,
|
||||||
uploadURL: url,
|
uploadURL: url,
|
||||||
enablePersistence: true,
|
enablePersistence: true,
|
||||||
@@ -236,7 +253,10 @@ func (s *serviceClient) getCreateHandler(
|
|||||||
|
|
||||||
statusLabel.SetText("Creating debug bundle...")
|
statusLabel.SetText("Creating debug bundle...")
|
||||||
go s.handleDebugCreation(
|
go s.handleDebugCreation(
|
||||||
params,
|
anonymizeCheck.Checked,
|
||||||
|
systemInfoCheck.Checked,
|
||||||
|
uploadCheck.Checked,
|
||||||
|
url,
|
||||||
statusLabel,
|
statusLabel,
|
||||||
uiControls,
|
uiControls,
|
||||||
w,
|
w,
|
||||||
@@ -351,7 +371,7 @@ func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time
|
|||||||
func (s *serviceClient) configureServiceForDebug(
|
func (s *serviceClient) configureServiceForDebug(
|
||||||
conn proto.DaemonServiceClient,
|
conn proto.DaemonServiceClient,
|
||||||
state *debugInitialState,
|
state *debugInitialState,
|
||||||
params *debugCollectionParams,
|
enablePersistence bool,
|
||||||
) {
|
) {
|
||||||
if state.wasDown {
|
if state.wasDown {
|
||||||
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {
|
||||||
@@ -377,7 +397,7 @@ func (s *serviceClient) configureServiceForDebug(
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.enablePersistence {
|
if enablePersistence {
|
||||||
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
|
if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
@@ -397,26 +417,6 @@ func (s *serviceClient) configureServiceForDebug(
|
|||||||
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
|
if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
|
||||||
log.Warnf("failed to start CPU profiling: %v", err)
|
log.Warnf("failed to start CPU profiling: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
s.startBundleCaptureIfEnabled(conn, params)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) {
|
|
||||||
if !params.capture {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
const maxCapture = 10 * time.Minute
|
|
||||||
timeout := params.duration + 30*time.Second
|
|
||||||
if timeout > maxCapture {
|
|
||||||
timeout = maxCapture
|
|
||||||
log.Warnf("packet capture clamped to %s (server maximum)", maxCapture)
|
|
||||||
}
|
|
||||||
if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
|
|
||||||
Timeout: durationpb.New(timeout),
|
|
||||||
}); err != nil {
|
|
||||||
log.Warnf("failed to start bundle capture: %v", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) collectDebugData(
|
func (s *serviceClient) collectDebugData(
|
||||||
@@ -430,7 +430,7 @@ func (s *serviceClient) collectDebugData(
|
|||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
startProgressTracker(ctx, &wg, params.duration, progress)
|
startProgressTracker(ctx, &wg, params.duration, progress)
|
||||||
|
|
||||||
s.configureServiceForDebug(conn, state, params)
|
s.configureServiceForDebug(conn, state, params.enablePersistence)
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
progress.progressBar.Hide()
|
progress.progressBar.Hide()
|
||||||
@@ -440,14 +440,6 @@ func (s *serviceClient) collectDebugData(
|
|||||||
log.Warnf("failed to stop CPU profiling: %v", err)
|
log.Warnf("failed to stop CPU profiling: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if params.capture {
|
|
||||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
|
||||||
log.Warnf("failed to stop bundle capture: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -528,37 +520,18 @@ func handleError(progress *progressUI, errMsg string) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceClient) handleDebugCreation(
|
func (s *serviceClient) handleDebugCreation(
|
||||||
params *debugCollectionParams,
|
anonymize bool,
|
||||||
|
systemInfo bool,
|
||||||
|
upload bool,
|
||||||
|
uploadURL string,
|
||||||
statusLabel *widget.Label,
|
statusLabel *widget.Label,
|
||||||
uiControls []fyne.Disableable,
|
uiControls []fyne.Disableable,
|
||||||
w fyne.Window,
|
w fyne.Window,
|
||||||
) {
|
) {
|
||||||
conn, err := s.getSrvClient(failFastTimeout)
|
log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...",
|
||||||
if err != nil {
|
anonymize, systemInfo, upload)
|
||||||
log.Errorf("Failed to get client for debug: %v", err)
|
|
||||||
statusLabel.SetText(fmt.Sprintf("Error: %v", err))
|
|
||||||
enableUIControls(uiControls)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if params.capture {
|
resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL)
|
||||||
if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
|
|
||||||
Timeout: durationpb.New(30 * time.Second),
|
|
||||||
}); err != nil {
|
|
||||||
log.Warnf("failed to start bundle capture: %v", err)
|
|
||||||
} else {
|
|
||||||
defer func() {
|
|
||||||
stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
|
|
||||||
log.Warnf("failed to stop bundle capture: %v", err)
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
time.Sleep(2 * time.Second)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("Failed to create debug bundle: %v", err)
|
log.Errorf("Failed to create debug bundle: %v", err)
|
||||||
statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))
|
statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))
|
||||||
@@ -570,7 +543,7 @@ func (s *serviceClient) handleDebugCreation(
|
|||||||
uploadFailureReason := resp.GetUploadFailureReason()
|
uploadFailureReason := resp.GetUploadFailureReason()
|
||||||
uploadedKey := resp.GetUploadedKey()
|
uploadedKey := resp.GetUploadedKey()
|
||||||
|
|
||||||
if params.upload {
|
if upload {
|
||||||
if uploadFailureReason != "" {
|
if uploadFailureReason != "" {
|
||||||
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
showUploadFailedDialog(w, localPath, uploadFailureReason)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -8,6 +8,7 @@ import (
|
|||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"fyne.io/fyne/v2"
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
@@ -17,17 +18,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/ui/desktop"
|
"github.com/netbirdio/netbird/client/ui/desktop"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Notifier sends desktop notifications. Defined here so the event package
|
|
||||||
// does not depend on fyne or the platform-specific notifier implementation.
|
|
||||||
type Notifier interface {
|
|
||||||
Send(title, body string)
|
|
||||||
}
|
|
||||||
|
|
||||||
type Handler func(*proto.SystemEvent)
|
type Handler func(*proto.SystemEvent)
|
||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
notifier Notifier
|
app fyne.App
|
||||||
addr string
|
addr string
|
||||||
|
|
||||||
mu sync.Mutex
|
mu sync.Mutex
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -36,10 +31,10 @@ type Manager struct {
|
|||||||
handlers []Handler
|
handlers []Handler
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(notifier Notifier, addr string) *Manager {
|
func NewManager(app fyne.App, addr string) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
notifier: notifier,
|
app: app,
|
||||||
addr: addr,
|
addr: addr,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -119,7 +114,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
|
|||||||
if id != "" {
|
if id != "" {
|
||||||
body += fmt.Sprintf(" ID: %s", id)
|
body += fmt.Sprintf(" ID: %s", id)
|
||||||
}
|
}
|
||||||
e.notifier.Send(title, body)
|
e.app.SendNotification(fyne.NewNotification(title, body))
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, handler := range handlers {
|
for _, handler := range handlers {
|
||||||
|
|||||||
@@ -9,6 +9,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
|
|
||||||
|
"fyne.io/fyne/v2"
|
||||||
"fyne.io/systray"
|
"fyne.io/systray"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
@@ -86,7 +87,7 @@ func (h *eventHandler) handleConnectClick() {
|
|||||||
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
|
||||||
log.Debugf("connect operation cancelled by user")
|
log.Debugf("connect operation cancelled by user")
|
||||||
} else {
|
} else {
|
||||||
h.client.notifier.Send("Error", "Failed to connect")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
|
||||||
log.Errorf("connect failed: %v", err)
|
log.Errorf("connect failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -111,7 +112,7 @@ func (h *eventHandler) handleDisconnectClick() {
|
|||||||
if err := h.client.menuDownClick(); err != nil {
|
if err := h.client.menuDownClick(); err != nil {
|
||||||
st, ok := status.FromError(err)
|
st, ok := status.FromError(err)
|
||||||
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
|
||||||
h.client.notifier.Send("Error", "Failed to disconnect")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
|
||||||
log.Errorf("disconnect failed: %v", err)
|
log.Errorf("disconnect failed: %v", err)
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("disconnect cancelled or already disconnecting")
|
log.Debugf("disconnect cancelled or already disconnecting")
|
||||||
@@ -129,7 +130,7 @@ func (h *eventHandler) handleAllowSSHClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update SSH settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings"))
|
||||||
}
|
}
|
||||||
|
|
||||||
}
|
}
|
||||||
@@ -139,7 +140,7 @@ func (h *eventHandler) handleAutoConnectClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update auto-connect settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -148,7 +149,7 @@ func (h *eventHandler) handleRosenpassClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update Rosenpass settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +158,7 @@ func (h *eventHandler) handleLazyConnectionClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update lazy connection settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -166,7 +167,7 @@ func (h *eventHandler) handleBlockInboundClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update block inbound settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +176,7 @@ func (h *eventHandler) handleNotificationsClick() {
|
|||||||
if err := h.updateConfigWithErr(); err != nil {
|
if err := h.updateConfigWithErr(); err != nil {
|
||||||
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
|
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
|
||||||
log.Errorf("failed to update config: %v", err)
|
log.Errorf("failed to update config: %v", err)
|
||||||
h.client.notifier.Send("Error", "Failed to update notifications settings")
|
h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings"))
|
||||||
} else if h.client.eventManager != nil {
|
} else if h.client.eventManager != nil {
|
||||||
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
|
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,27 +0,0 @@
|
|||||||
// Package notifier sends desktop notifications. On Windows it uses the WinRT
|
|
||||||
// COM API directly via go-toast/v2 to avoid the PowerShell window flash that
|
|
||||||
// fyne's default implementation produces. On other platforms it delegates to
|
|
||||||
// fyne.
|
|
||||||
package notifier
|
|
||||||
|
|
||||||
import "fyne.io/fyne/v2"
|
|
||||||
|
|
||||||
// Notifier sends desktop notifications.
|
|
||||||
type Notifier interface {
|
|
||||||
Send(title, body string)
|
|
||||||
}
|
|
||||||
|
|
||||||
// New returns a platform-specific Notifier. The fyne app is used as the
|
|
||||||
// fallback notifier on platforms where no native implementation is wired up,
|
|
||||||
// and on Windows when the COM path fails to initialize.
|
|
||||||
func New(app fyne.App) Notifier {
|
|
||||||
return newNotifier(app)
|
|
||||||
}
|
|
||||||
|
|
||||||
type fyneNotifier struct {
|
|
||||||
app fyne.App
|
|
||||||
}
|
|
||||||
|
|
||||||
func (f *fyneNotifier) Send(title, body string) {
|
|
||||||
f.app.SendNotification(fyne.NewNotification(title, body))
|
|
||||||
}
|
|
||||||
@@ -1,9 +0,0 @@
|
|||||||
//go:build !windows
|
|
||||||
|
|
||||||
package notifier
|
|
||||||
|
|
||||||
import "fyne.io/fyne/v2"
|
|
||||||
|
|
||||||
func newNotifier(app fyne.App) Notifier {
|
|
||||||
return &fyneNotifier{app: app}
|
|
||||||
}
|
|
||||||
@@ -1,88 +0,0 @@
|
|||||||
package notifier
|
|
||||||
|
|
||||||
import (
|
|
||||||
"os"
|
|
||||||
"path/filepath"
|
|
||||||
"sync"
|
|
||||||
|
|
||||||
"fyne.io/fyne/v2"
|
|
||||||
toast "git.sr.ht/~jackmordaunt/go-toast/v2"
|
|
||||||
"git.sr.ht/~jackmordaunt/go-toast/v2/wintoast"
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
)
|
|
||||||
|
|
||||||
const (
|
|
||||||
// appID is the AppUserModelID shown in the Windows Action Center. It
|
|
||||||
// must match the System.AppUserModel.ID property set on the Start Menu
|
|
||||||
// shortcut by the MSI (see client/netbird.wxs); otherwise Windows
|
|
||||||
// groups toasts under a separate, unbranded entry.
|
|
||||||
appID = "NetBird"
|
|
||||||
|
|
||||||
// appGUID identifies the COM activation callback class. Generated once
|
|
||||||
// for NetBird; do not change without coordinating an installer bump,
|
|
||||||
// since old registry entries pointing at the previous GUID would orphan.
|
|
||||||
appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}"
|
|
||||||
)
|
|
||||||
|
|
||||||
type comNotifier struct {
|
|
||||||
fallback *fyneNotifier
|
|
||||||
ready bool
|
|
||||||
iconPath string
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
|
||||||
initOnce sync.Once
|
|
||||||
initErr error
|
|
||||||
)
|
|
||||||
|
|
||||||
func newNotifier(app fyne.App) Notifier {
|
|
||||||
n := &comNotifier{
|
|
||||||
fallback: &fyneNotifier{app: app},
|
|
||||||
iconPath: resolveIcon(),
|
|
||||||
}
|
|
||||||
initOnce.Do(func() {
|
|
||||||
initErr = wintoast.SetAppData(wintoast.AppData{
|
|
||||||
AppID: appID,
|
|
||||||
GUID: appGUID,
|
|
||||||
IconPath: n.iconPath,
|
|
||||||
})
|
|
||||||
})
|
|
||||||
if initErr != nil {
|
|
||||||
log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr)
|
|
||||||
return n.fallback
|
|
||||||
}
|
|
||||||
n.ready = true
|
|
||||||
return n
|
|
||||||
}
|
|
||||||
|
|
||||||
func (n *comNotifier) Send(title, body string) {
|
|
||||||
if !n.ready {
|
|
||||||
n.fallback.Send(title, body)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
notification := toast.Notification{
|
|
||||||
AppID: appID,
|
|
||||||
Title: title,
|
|
||||||
Body: body,
|
|
||||||
Icon: n.iconPath,
|
|
||||||
}
|
|
||||||
if err := notification.Push(); err != nil {
|
|
||||||
log.Warnf("toast: push failed, using fyne fallback: %v", err)
|
|
||||||
n.fallback.Send(title, body)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// resolveIcon returns an absolute path to the toast icon, or an empty string
|
|
||||||
// when no icon can be located. Windows requires a PNG/JPG for the
|
|
||||||
// AppUserModelId IconUri registry value; .ico is silently ignored.
|
|
||||||
func resolveIcon() string {
|
|
||||||
exe, err := os.Executable()
|
|
||||||
if err != nil {
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
candidate := filepath.Join(filepath.Dir(exe), "netbird.png")
|
|
||||||
if _, err := os.Stat(candidate); err == nil {
|
|
||||||
return candidate
|
|
||||||
}
|
|
||||||
return ""
|
|
||||||
}
|
|
||||||
@@ -548,7 +548,7 @@ func (p *profileMenu) refresh() {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to switch profile: %v", err)
|
log.Errorf("failed to switch profile: %v", err)
|
||||||
// show notification dialog
|
// show notification dialog
|
||||||
p.serviceClient.notifier.Send("Error", "Failed to switch profile")
|
p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile"))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -628,9 +628,9 @@ func (p *profileMenu) refresh() {
|
|||||||
}
|
}
|
||||||
if err := p.eventHandler.logout(p.ctx); err != nil {
|
if err := p.eventHandler.logout(p.ctx); err != nil {
|
||||||
log.Errorf("logout failed: %v", err)
|
log.Errorf("logout failed: %v", err)
|
||||||
p.serviceClient.notifier.Send("Error", "Failed to deregister")
|
p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
|
||||||
} else {
|
} else {
|
||||||
p.serviceClient.notifier.Send("Success", "Deregistered successfully")
|
p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@ package main
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
|
||||||
"syscall/js"
|
"syscall/js"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -15,7 +14,6 @@ import (
|
|||||||
netbird "github.com/netbirdio/netbird/client/embed"
|
netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture"
|
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
"github.com/netbirdio/netbird/client/wasm/internal/http"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
"github.com/netbirdio/netbird/client/wasm/internal/rdp"
|
||||||
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
"github.com/netbirdio/netbird/client/wasm/internal/ssh"
|
||||||
@@ -461,95 +459,6 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
// createStartCaptureMethod creates the programmable packet capture method.
|
|
||||||
// Returns a JS interface with onpacket callback and stop() method.
|
|
||||||
//
|
|
||||||
// Usage from JavaScript:
|
|
||||||
//
|
|
||||||
// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true })
|
|
||||||
// cap.onpacket = (line) => console.log(line)
|
|
||||||
// const stats = cap.stop()
|
|
||||||
func createStartCaptureMethod(client *netbird.Client) js.Func {
|
|
||||||
return js.FuncOf(func(_ js.Value, args []js.Value) any {
|
|
||||||
var opts js.Value
|
|
||||||
if len(args) > 0 {
|
|
||||||
opts = args[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return createPromise(func(resolve, reject js.Value) {
|
|
||||||
iface, err := wasmcapture.Start(client, opts)
|
|
||||||
if err != nil {
|
|
||||||
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resolve.Invoke(iface)
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// captureMethods returns capture() and stopCapture() that share state for
|
|
||||||
// the console-log shortcut. capture() logs packets to the browser console
|
|
||||||
// and stopCapture() ends it, like Ctrl+C on the CLI.
|
|
||||||
//
|
|
||||||
// Usage from browser devtools console:
|
|
||||||
//
|
|
||||||
// await client.capture() // capture all packets
|
|
||||||
// await client.capture("tcp") // capture with filter
|
|
||||||
// await client.capture({filter: "host 10.0.0.1", verbose: true})
|
|
||||||
// client.stopCapture() // stop and print stats
|
|
||||||
func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
|
|
||||||
var mu sync.Mutex
|
|
||||||
var active *wasmcapture.Handle
|
|
||||||
|
|
||||||
startFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
|
|
||||||
var opts js.Value
|
|
||||||
if len(args) > 0 {
|
|
||||||
opts = args[0]
|
|
||||||
}
|
|
||||||
|
|
||||||
return createPromise(func(resolve, reject js.Value) {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
if active != nil {
|
|
||||||
active.Stop()
|
|
||||||
active = nil
|
|
||||||
}
|
|
||||||
|
|
||||||
h, err := wasmcapture.StartConsole(client, opts)
|
|
||||||
if err != nil {
|
|
||||||
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
active = h
|
|
||||||
|
|
||||||
console := js.Global().Get("console")
|
|
||||||
console.Call("log", "[capture] started, call client.stopCapture() to stop")
|
|
||||||
resolve.Invoke(js.Undefined())
|
|
||||||
})
|
|
||||||
})
|
|
||||||
|
|
||||||
stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
|
||||||
mu.Lock()
|
|
||||||
defer mu.Unlock()
|
|
||||||
|
|
||||||
if active == nil {
|
|
||||||
js.Global().Get("console").Call("log", "[capture] no active capture")
|
|
||||||
return js.Undefined()
|
|
||||||
}
|
|
||||||
|
|
||||||
stats := active.Stop()
|
|
||||||
active = nil
|
|
||||||
|
|
||||||
console := js.Global().Get("console")
|
|
||||||
console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped",
|
|
||||||
stats.Packets, stats.Bytes, stats.Dropped))
|
|
||||||
return js.Undefined()
|
|
||||||
})
|
|
||||||
|
|
||||||
return startFn, stopFn
|
|
||||||
}
|
|
||||||
|
|
||||||
// createPromise is a helper to create JavaScript promises
|
// createPromise is a helper to create JavaScript promises
|
||||||
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
func createPromise(handler func(resolve, reject js.Value)) js.Value {
|
||||||
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
|
||||||
@@ -612,11 +521,6 @@ func createClientObject(client *netbird.Client) js.Value {
|
|||||||
obj["statusDetail"] = createStatusDetailMethod(client)
|
obj["statusDetail"] = createStatusDetailMethod(client)
|
||||||
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
obj["getSyncResponse"] = createGetSyncResponseMethod(client)
|
||||||
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
obj["setLogLevel"] = createSetLogLevelMethod(client)
|
||||||
obj["startCapture"] = createStartCaptureMethod(client)
|
|
||||||
|
|
||||||
capStart, capStop := captureMethods(client)
|
|
||||||
obj["capture"] = capStart
|
|
||||||
obj["stopCapture"] = capStop
|
|
||||||
|
|
||||||
return js.ValueOf(obj)
|
return js.ValueOf(obj)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,176 +0,0 @@
|
|||||||
//go:build js
|
|
||||||
|
|
||||||
// Package capture bridges the util/capture package to JavaScript via syscall/js.
|
|
||||||
package capture
|
|
||||||
|
|
||||||
import (
|
|
||||||
"strings"
|
|
||||||
"sync"
|
|
||||||
"syscall/js"
|
|
||||||
|
|
||||||
netbird "github.com/netbirdio/netbird/client/embed"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Handle holds a running capture session so it can be stopped later.
|
|
||||||
type Handle struct {
|
|
||||||
cs *netbird.CaptureSession
|
|
||||||
stopFn js.Func
|
|
||||||
stopped bool
|
|
||||||
}
|
|
||||||
|
|
||||||
// Stop ends the capture and returns stats.
|
|
||||||
func (h *Handle) Stop() netbird.CaptureStats {
|
|
||||||
if h.stopped {
|
|
||||||
return h.cs.Stats()
|
|
||||||
}
|
|
||||||
h.stopped = true
|
|
||||||
h.stopFn.Release()
|
|
||||||
|
|
||||||
h.cs.Stop()
|
|
||||||
return h.cs.Stats()
|
|
||||||
}
|
|
||||||
|
|
||||||
func statsToJS(s netbird.CaptureStats) js.Value {
|
|
||||||
obj := js.Global().Get("Object").Call("create", js.Null())
|
|
||||||
obj.Set("packets", js.ValueOf(s.Packets))
|
|
||||||
obj.Set("bytes", js.ValueOf(s.Bytes))
|
|
||||||
obj.Set("dropped", js.ValueOf(s.Dropped))
|
|
||||||
return obj
|
|
||||||
}
|
|
||||||
|
|
||||||
// parseOpts extracts filter/verbose/ascii from a JS options value.
|
|
||||||
func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) {
|
|
||||||
if jsOpts.IsNull() || jsOpts.IsUndefined() {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if jsOpts.Type() == js.TypeString {
|
|
||||||
filter = jsOpts.String()
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if jsOpts.Type() != js.TypeObject {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() {
|
|
||||||
filter = f.String()
|
|
||||||
}
|
|
||||||
if v := jsOpts.Get("verbose"); !v.IsUndefined() {
|
|
||||||
verbose = v.Truthy()
|
|
||||||
}
|
|
||||||
if a := jsOpts.Get("ascii"); !a.IsUndefined() {
|
|
||||||
ascii = a.Truthy()
|
|
||||||
}
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start creates a capture session and returns a JS interface for streaming text
|
|
||||||
// output. The returned object exposes:
|
|
||||||
//
|
|
||||||
// onpacket(callback) - set callback(string) for each text line
|
|
||||||
// stop() - stop capture and return stats { packets, bytes, dropped }
|
|
||||||
//
|
|
||||||
// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string.
|
|
||||||
func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) {
|
|
||||||
filter, verbose, ascii := parseOpts(jsOpts)
|
|
||||||
|
|
||||||
cb := &jsCallbackWriter{}
|
|
||||||
|
|
||||||
cs, err := client.StartCapture(netbird.CaptureOptions{
|
|
||||||
TextOutput: cb,
|
|
||||||
Filter: filter,
|
|
||||||
Verbose: verbose,
|
|
||||||
ASCII: ascii,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return js.Undefined(), err
|
|
||||||
}
|
|
||||||
|
|
||||||
handle := &Handle{cs: cs}
|
|
||||||
|
|
||||||
iface := js.Global().Get("Object").Call("create", js.Null())
|
|
||||||
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
|
||||||
return statsToJS(handle.Stop())
|
|
||||||
})
|
|
||||||
iface.Set("stop", handle.stopFn)
|
|
||||||
iface.Set("onpacket", js.Undefined())
|
|
||||||
cb.setInterface(iface)
|
|
||||||
|
|
||||||
return iface, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// StartConsole starts a capture that logs every packet line to console.log.
|
|
||||||
// Returns a Handle so the caller can stop it later.
|
|
||||||
func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) {
|
|
||||||
filter, verbose, ascii := parseOpts(jsOpts)
|
|
||||||
|
|
||||||
cb := &jsCallbackWriter{}
|
|
||||||
|
|
||||||
cs, err := client.StartCapture(netbird.CaptureOptions{
|
|
||||||
TextOutput: cb,
|
|
||||||
Filter: filter,
|
|
||||||
Verbose: verbose,
|
|
||||||
ASCII: ascii,
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
handle := &Handle{cs: cs}
|
|
||||||
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
|
|
||||||
return statsToJS(handle.Stop())
|
|
||||||
})
|
|
||||||
|
|
||||||
iface := js.Global().Get("Object").Call("create", js.Null())
|
|
||||||
console := js.Global().Get("console")
|
|
||||||
iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]")))
|
|
||||||
cb.setInterface(iface)
|
|
||||||
|
|
||||||
return handle, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// jsCallbackWriter is an io.Writer that buffers text until a newline, then
|
|
||||||
// invokes the JS onpacket callback with each complete line.
|
|
||||||
type jsCallbackWriter struct {
|
|
||||||
mu sync.Mutex
|
|
||||||
iface js.Value
|
|
||||||
buf strings.Builder
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *jsCallbackWriter) setInterface(iface js.Value) {
|
|
||||||
w.mu.Lock()
|
|
||||||
defer w.mu.Unlock()
|
|
||||||
w.iface = iface
|
|
||||||
}
|
|
||||||
|
|
||||||
func (w *jsCallbackWriter) Write(p []byte) (int, error) {
|
|
||||||
w.mu.Lock()
|
|
||||||
w.buf.Write(p)
|
|
||||||
|
|
||||||
var lines []string
|
|
||||||
for {
|
|
||||||
str := w.buf.String()
|
|
||||||
idx := strings.IndexByte(str, '\n')
|
|
||||||
if idx < 0 {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
lines = append(lines, str[:idx])
|
|
||||||
w.buf.Reset()
|
|
||||||
if idx+1 < len(str) {
|
|
||||||
w.buf.WriteString(str[idx+1:])
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
iface := w.iface
|
|
||||||
w.mu.Unlock()
|
|
||||||
|
|
||||||
if iface.IsUndefined() {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
cb := iface.Get("onpacket")
|
|
||||||
if cb.IsUndefined() || cb.IsNull() {
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
for _, line := range lines {
|
|
||||||
cb.Invoke(js.ValueOf(line))
|
|
||||||
}
|
|
||||||
return len(p), nil
|
|
||||||
}
|
|
||||||
@@ -13,9 +13,11 @@ import (
|
|||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
|
"google.golang.org/grpc/codes"
|
||||||
"google.golang.org/grpc/credentials"
|
"google.golang.org/grpc/credentials"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
"google.golang.org/grpc/keepalive"
|
"google.golang.org/grpc/keepalive"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
nbgrpc "github.com/netbirdio/netbird/client/grpc"
|
||||||
"github.com/netbirdio/netbird/flow/proto"
|
"github.com/netbirdio/netbird/flow/proto"
|
||||||
@@ -299,11 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
|
|||||||
}, ctx)
|
}, ctx)
|
||||||
}
|
}
|
||||||
|
|
||||||
// isContextDone reports whether the local context has been canceled or has
|
|
||||||
// exceeded its deadline. It deliberately does not inspect gRPC status codes:
|
|
||||||
// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not
|
|
||||||
// short-circuit our retry loop, since retrying is the correct response when
|
|
||||||
// the local context is still alive.
|
|
||||||
func isContextDone(err error) bool {
|
func isContextDone(err error) bool {
|
||||||
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded)
|
if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
if s, ok := status.FromError(err); ok {
|
||||||
|
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
|
||||||
|
}
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|||||||
5
go.mod
5
go.mod
@@ -30,7 +30,6 @@ require (
|
|||||||
require (
|
require (
|
||||||
fyne.io/fyne/v2 v2.7.0
|
fyne.io/fyne/v2 v2.7.0
|
||||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
|
||||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
|
|
||||||
github.com/awnumar/memguard v0.23.0
|
github.com/awnumar/memguard v0.23.0
|
||||||
github.com/aws/aws-sdk-go-v2 v1.38.3
|
github.com/aws/aws-sdk-go-v2 v1.38.3
|
||||||
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
github.com/aws/aws-sdk-go-v2/config v1.31.6
|
||||||
@@ -47,7 +46,6 @@ require (
|
|||||||
github.com/crowdsecurity/go-cs-bouncer v0.0.21
|
github.com/crowdsecurity/go-cs-bouncer v0.0.21
|
||||||
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
github.com/dexidp/dex v0.0.0-00010101000000-000000000000
|
||||||
github.com/dexidp/dex/api/v2 v2.4.0
|
github.com/dexidp/dex/api/v2 v2.4.0
|
||||||
github.com/ebitengine/purego v0.8.4
|
|
||||||
github.com/eko/gocache/lib/v4 v4.2.0
|
github.com/eko/gocache/lib/v4 v4.2.0
|
||||||
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
github.com/eko/gocache/store/go_cache/v4 v4.2.2
|
||||||
github.com/eko/gocache/store/redis/v4 v4.2.2
|
github.com/eko/gocache/store/redis/v4 v4.2.2
|
||||||
@@ -71,7 +69,6 @@ require (
|
|||||||
github.com/libp2p/go-netroute v0.2.1
|
github.com/libp2p/go-netroute v0.2.1
|
||||||
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
|
||||||
github.com/mdlayher/socket v0.5.1
|
github.com/mdlayher/socket v0.5.1
|
||||||
github.com/mdp/qrterminal/v3 v3.2.1
|
|
||||||
github.com/miekg/dns v1.1.59
|
github.com/miekg/dns v1.1.59
|
||||||
github.com/mitchellh/hashstructure/v2 v2.0.2
|
github.com/mitchellh/hashstructure/v2 v2.0.2
|
||||||
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
|
||||||
@@ -181,6 +178,7 @@ require (
|
|||||||
github.com/docker/docker v28.0.1+incompatible // indirect
|
github.com/docker/docker v28.0.1+incompatible // indirect
|
||||||
github.com/docker/go-connections v0.6.0 // indirect
|
github.com/docker/go-connections v0.6.0 // indirect
|
||||||
github.com/docker/go-units v0.5.0 // indirect
|
github.com/docker/go-units v0.5.0 // indirect
|
||||||
|
github.com/ebitengine/purego v0.8.4 // indirect
|
||||||
github.com/felixge/httpsnoop v1.0.4 // indirect
|
github.com/felixge/httpsnoop v1.0.4 // indirect
|
||||||
github.com/fredbi/uri v1.1.1 // indirect
|
github.com/fredbi/uri v1.1.1 // indirect
|
||||||
github.com/fyne-io/gl-js v0.2.0 // indirect
|
github.com/fyne-io/gl-js v0.2.0 // indirect
|
||||||
@@ -310,7 +308,6 @@ require (
|
|||||||
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
gopkg.in/square/go-jose.v2 v2.6.0 // indirect
|
||||||
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
|
||||||
gopkg.in/yaml.v2 v2.4.0 // indirect
|
gopkg.in/yaml.v2 v2.4.0 // indirect
|
||||||
rsc.io/qr v0.2.0 // indirect
|
|
||||||
)
|
)
|
||||||
|
|
||||||
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
|
||||||
|
|||||||
6
go.sum
6
go.sum
@@ -15,8 +15,6 @@ fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
|
|||||||
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
|
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
|
||||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
|
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
|
||||||
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
|
||||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 h1:N3IGoHHp9pb6mj1cbXbuaSXV/UMKwmbKLf53nQmtqMA=
|
|
||||||
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3/go.mod h1:QtOLZGz8olr4qH2vWK0QH0w0O4T9fEIjMuWpKUsH7nc=
|
|
||||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
|
||||||
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
|
||||||
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
|
||||||
@@ -415,8 +413,6 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
|
|||||||
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
|
||||||
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
|
||||||
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
|
||||||
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
|
|
||||||
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
|
|
||||||
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
|
||||||
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
|
||||||
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
|
||||||
@@ -917,5 +913,3 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
|
|||||||
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
|
||||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
|
||||||
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
|
||||||
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
|
|
||||||
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
|
|
||||||
|
|||||||
@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
|
|||||||
// are stored with types that Dex can open.
|
// are stored with types that Dex can open.
|
||||||
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
|
||||||
switch connType {
|
switch connType {
|
||||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||||
return "oidc", applyOIDCDefaults(connType, config)
|
return "oidc", applyOIDCDefaults(connType, config)
|
||||||
default:
|
default:
|
||||||
return connType, config
|
return connType, config
|
||||||
@@ -218,8 +218,6 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
|
|||||||
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
|
||||||
case "okta", "pocketid":
|
case "okta", "pocketid":
|
||||||
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "adfs":
|
|
||||||
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return augmented
|
return augmented
|
||||||
|
|||||||
@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
|
|||||||
var err error
|
var err error
|
||||||
|
|
||||||
switch cfg.Type {
|
switch cfg.Type {
|
||||||
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs":
|
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
|
||||||
dexType = "oidc"
|
dexType = "oidc"
|
||||||
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
|
||||||
case "google":
|
case "google":
|
||||||
@@ -220,8 +220,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
|
|||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "pocketid":
|
case "pocketid":
|
||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
|
||||||
case "adfs":
|
|
||||||
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
|
|
||||||
}
|
}
|
||||||
return encodeConnectorConfig(oidcConfig)
|
return encodeConnectorConfig(oidcConfig)
|
||||||
}
|
}
|
||||||
@@ -285,7 +283,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
|
|||||||
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
// inferOIDCProviderType infers the specific OIDC provider from connector ID
|
||||||
func inferOIDCProviderType(connectorID string) string {
|
func inferOIDCProviderType(connectorID string) string {
|
||||||
connectorIDLower := strings.ToLower(connectorID)
|
connectorIDLower := strings.ToLower(connectorID)
|
||||||
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} {
|
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
|
||||||
if strings.Contains(connectorIDLower, provider) {
|
if strings.Contains(connectorIDLower, provider) {
|
||||||
return provider
|
return provider
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -231,20 +231,7 @@ get_upstream_host() {
|
|||||||
|
|
||||||
wait_management_proxy() {
|
wait_management_proxy() {
|
||||||
local proxy_container="${1:-traefik}"
|
local proxy_container="${1:-traefik}"
|
||||||
local use_docker_logs=false
|
|
||||||
set +e
|
set +e
|
||||||
|
|
||||||
if [[ "$proxy_container" == "detect-traefik" ]]; then
|
|
||||||
proxy_container=$(docker ps --format "{{.ID}}\t{{.Image}}\t{{.Ports}}" \
|
|
||||||
| awk -F'\t' '$2 ~ /traefik/ && $3 ~ /:(80|443)->/ {print $1; exit}')
|
|
||||||
|
|
||||||
if [[ -z "$proxy_container" ]]; then
|
|
||||||
echo "Warning: could not auto-detect Traefik container, log output will be skipped on timeout." > /dev/stderr
|
|
||||||
else
|
|
||||||
use_docker_logs=true
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
|
|
||||||
echo -n "Waiting for NetBird server to become ready"
|
echo -n "Waiting for NetBird server to become ready"
|
||||||
counter=1
|
counter=1
|
||||||
while true; do
|
while true; do
|
||||||
@@ -255,13 +242,7 @@ wait_management_proxy() {
|
|||||||
if [[ $counter -eq 60 ]]; then
|
if [[ $counter -eq 60 ]]; then
|
||||||
echo ""
|
echo ""
|
||||||
echo "Taking too long. Checking logs..."
|
echo "Taking too long. Checking logs..."
|
||||||
if [[ -n "$proxy_container" ]]; then
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
|
||||||
if [[ "$use_docker_logs" == "true" ]]; then
|
|
||||||
docker logs --tail=20 "$proxy_container"
|
|
||||||
else
|
|
||||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
|
|
||||||
fi
|
|
||||||
fi
|
|
||||||
$DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server
|
$DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server
|
||||||
fi
|
fi
|
||||||
echo -n " ."
|
echo -n " ."
|
||||||
@@ -537,7 +518,7 @@ start_services_and_show_instructions() {
|
|||||||
$DOCKER_COMPOSE_COMMAND up -d
|
$DOCKER_COMPOSE_COMMAND up -d
|
||||||
|
|
||||||
sleep 3
|
sleep 3
|
||||||
wait_management_proxy detect-traefik
|
wait_management_direct
|
||||||
|
|
||||||
echo -e "$MSG_DONE"
|
echo -e "$MSG_DONE"
|
||||||
print_post_setup_instructions
|
print_post_setup_instructions
|
||||||
|
|||||||
@@ -7,6 +7,7 @@ import (
|
|||||||
"os"
|
"os"
|
||||||
"slices"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -15,9 +16,11 @@ import (
|
|||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
"golang.org/x/mod/semver"
|
"golang.org/x/mod/semver"
|
||||||
|
|
||||||
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map"
|
||||||
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
|
||||||
|
"github.com/netbirdio/netbird/management/internals/modules/zones"
|
||||||
"github.com/netbirdio/netbird/management/internals/server/config"
|
"github.com/netbirdio/netbird/management/internals/server/config"
|
||||||
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
"github.com/netbirdio/netbird/management/internals/shared/grpc"
|
||||||
"github.com/netbirdio/netbird/management/server/account"
|
"github.com/netbirdio/netbird/management/server/account"
|
||||||
@@ -55,6 +58,13 @@ type Controller struct {
|
|||||||
proxyController port_forwarding.Controller
|
proxyController port_forwarding.Controller
|
||||||
|
|
||||||
integratedPeerValidator integrated_validator.IntegratedValidator
|
integratedPeerValidator integrated_validator.IntegratedValidator
|
||||||
|
|
||||||
|
holder *types.Holder
|
||||||
|
|
||||||
|
expNewNetworkMap bool
|
||||||
|
expNewNetworkMapAIDs map[string]struct{}
|
||||||
|
|
||||||
|
compactedNetworkMap bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type bufferUpdate struct {
|
type bufferUpdate struct {
|
||||||
@@ -71,6 +81,29 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
log.Fatal(fmt.Errorf("error creating metrics: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
|
||||||
|
newNetworkMapBuilder = false
|
||||||
|
}
|
||||||
|
|
||||||
|
compactedNetworkMap := true
|
||||||
|
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
|
||||||
|
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
|
||||||
|
if err != nil && len(compactedEnv) > 0 {
|
||||||
|
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
|
||||||
|
}
|
||||||
|
if err == nil && !parsedCompactedNmap {
|
||||||
|
log.WithContext(ctx).Info("disabling compacted mode")
|
||||||
|
compactedNetworkMap = false
|
||||||
|
}
|
||||||
|
|
||||||
|
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
|
||||||
|
expIDs := make(map[string]struct{}, len(ids))
|
||||||
|
for _, id := range ids {
|
||||||
|
expIDs[id] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
return &Controller{
|
return &Controller{
|
||||||
repo: newRepository(store),
|
repo: newRepository(store),
|
||||||
metrics: nMetrics,
|
metrics: nMetrics,
|
||||||
@@ -84,6 +117,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
|
|||||||
|
|
||||||
proxyController: proxyController,
|
proxyController: proxyController,
|
||||||
EphemeralPeersManager: ephemeralPeersManager,
|
EphemeralPeersManager: ephemeralPeersManager,
|
||||||
|
|
||||||
|
holder: types.NewHolder(),
|
||||||
|
expNewNetworkMap: newNetworkMapBuilder,
|
||||||
|
expNewNetworkMapAIDs: expIDs,
|
||||||
|
|
||||||
|
compactedNetworkMap: compactedNetworkMap,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -114,9 +153,17 @@ func (c *Controller) CountStreams() int {
|
|||||||
|
|
||||||
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
var (
|
||||||
if err != nil {
|
account *types.Account
|
||||||
return fmt.Errorf("failed to get account: %v", err)
|
err error
|
||||||
|
)
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||||
|
} else {
|
||||||
|
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get account: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
globalStart := time.Now()
|
globalStart := time.Now()
|
||||||
@@ -150,6 +197,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
routers := account.GetResourceRoutersMap()
|
routers := account.GetResourceRoutersMap()
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
|
||||||
@@ -192,7 +243,16 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
|
|||||||
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
c.metrics.CountCalcPostureChecksDuration(time.Since(start))
|
||||||
start = time.Now()
|
start = time.Now()
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountID):
|
||||||
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
|
||||||
|
|
||||||
@@ -257,10 +317,11 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
|
|||||||
|
|
||||||
// UpdatePeers updates all peers that belong to an account.
|
// UpdatePeers updates all peers that belong to an account.
|
||||||
// Should be called when changes have to be synced to peers.
|
// Should be called when changes have to be synced to peers.
|
||||||
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
if c.accountManagerMetrics != nil {
|
if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
|
||||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
return fmt.Errorf("recalculate network map cache: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.sendUpdateAccountPeers(ctx, accountID)
|
return c.sendUpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -310,7 +371,16 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
var remotePeerNetworkMap *types.NetworkMap
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case c.experimentalNetworkMap(accountId):
|
||||||
|
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
|
case c.compactedNetworkMap:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
default:
|
||||||
|
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -334,13 +404,9 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
|
||||||
|
|
||||||
if c.accountManagerMetrics != nil {
|
|
||||||
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
|
|
||||||
}
|
|
||||||
|
|
||||||
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
|
||||||
b := bufUpd.(*bufferUpdate)
|
b := bufUpd.(*bufferUpdate)
|
||||||
|
|
||||||
@@ -355,14 +421,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer b.mu.Unlock()
|
defer b.mu.Unlock()
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||||
if !b.update.Load() {
|
if !b.update.Load() {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
b.update.Store(false)
|
b.update.Store(false)
|
||||||
if b.next == nil {
|
if b.next == nil {
|
||||||
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
|
||||||
_ = c.sendUpdateAccountPeers(ctx, accountID)
|
_ = c.UpdateAccountPeers(ctx, accountID)
|
||||||
})
|
})
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -385,9 +451,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return peer, emptyMap, nil, 0, nil
|
return peer, emptyMap, nil, 0, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
var (
|
||||||
if err != nil {
|
account *types.Account
|
||||||
return nil, nil, nil, 0, err
|
err error
|
||||||
|
)
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account = c.getAccountFromHolderOrInit(ctx, accountID)
|
||||||
|
} else {
|
||||||
|
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, nil, nil, 0, err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
account.InjectProxyPolicies(ctx)
|
||||||
@@ -419,10 +493,20 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return nil, nil, nil, 0, err
|
return nil, nil, nil, 0, err
|
||||||
}
|
}
|
||||||
|
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
var networkMap *types.NetworkMap
|
||||||
routers := account.GetResourceRoutersMap()
|
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
if c.experimentalNetworkMap(accountID) {
|
||||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
|
||||||
|
} else {
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
@@ -434,6 +518,108 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
|
|||||||
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
return peer, networkMap, postureChecks, dnsFwdPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
|
||||||
|
c.enrichAccountFromHolder(account)
|
||||||
|
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) getPeerNetworkMapExp(
|
||||||
|
ctx context.Context,
|
||||||
|
accountId string,
|
||||||
|
peerId string,
|
||||||
|
validatedPeers map[string]struct{},
|
||||||
|
peersCustomZone nbdns.CustomZone,
|
||||||
|
accountZones []*zones.Zone,
|
||||||
|
metrics *telemetry.AccountManagerMetrics,
|
||||||
|
) *types.NetworkMap {
|
||||||
|
account := c.getAccountFromHolderOrInit(ctx, accountId)
|
||||||
|
if account == nil {
|
||||||
|
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
|
||||||
|
return &types.NetworkMap{
|
||||||
|
Network: &types.Network{},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
|
||||||
|
c.enrichAccountFromHolder(account)
|
||||||
|
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
|
||||||
|
c.enrichAccountFromHolder(account)
|
||||||
|
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
|
||||||
|
account := c.getAccountFromHolder(accountId)
|
||||||
|
if account == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account.UpdatePeerInNetworkMapCache(peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
|
||||||
|
account.RecalculateNetworkMapCache(validatedPeers)
|
||||||
|
c.updateAccountInHolder(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
|
||||||
|
if c.experimentalNetworkMap(accountId) {
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
c.recalculateNetworkMapCache(account, validatedPeers)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) experimentalNetworkMap(accountId string) bool {
|
||||||
|
_, ok := c.expNewNetworkMapAIDs[accountId]
|
||||||
|
return c.expNewNetworkMap || ok
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
|
||||||
|
a := c.holder.GetAccount(account.Id)
|
||||||
|
if a == nil {
|
||||||
|
c.holder.AddAccount(account)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
account.NetworkMapCache = a.NetworkMapCache
|
||||||
|
if account.NetworkMapCache == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
c.holder.AddAccount(account)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
|
||||||
|
return c.holder.GetAccount(accountID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
|
||||||
|
a := c.holder.GetAccount(accountID)
|
||||||
|
if a != nil {
|
||||||
|
return a
|
||||||
|
}
|
||||||
|
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
|
||||||
|
if err != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return account
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Controller) updateAccountInHolder(account *types.Account) {
|
||||||
|
c.holder.AddAccount(account)
|
||||||
|
}
|
||||||
|
|
||||||
// GetDNSDomain returns the configured dnsDomain
|
// GetDNSDomain returns the configured dnsDomain
|
||||||
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
func (c *Controller) GetDNSDomain(settings *types.Settings) string {
|
||||||
if settings == nil {
|
if settings == nil {
|
||||||
@@ -570,7 +756,16 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
err := c.bufferSendUpdateAccountPeers(ctx, accountID)
|
peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to get peers by ids: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range peers {
|
||||||
|
c.UpdatePeerInNetworkMapCache(accountID, peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
|
||||||
}
|
}
|
||||||
@@ -580,6 +775,14 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
|
|||||||
|
|
||||||
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
|
||||||
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
|
||||||
|
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
|
||||||
|
}
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -614,6 +817,19 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
|
|||||||
MessageType: network_map.MessageTypeNetworkMap,
|
MessageType: network_map.MessageTypeNetworkMap,
|
||||||
})
|
})
|
||||||
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
c.peersUpdateManager.CloseChannel(ctx, peerID)
|
||||||
|
|
||||||
|
if c.experimentalNetworkMap(accountID) {
|
||||||
|
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
|
||||||
|
if err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
return c.bufferSendUpdateAccountPeers(ctx, accountID)
|
||||||
@@ -656,11 +872,21 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
account.InjectProxyPolicies(ctx)
|
var networkMap *types.NetworkMap
|
||||||
resourcePolicies := account.GetResourcePoliciesMap()
|
|
||||||
routers := account.GetResourceRoutersMap()
|
if c.experimentalNetworkMap(peer.AccountID) {
|
||||||
groupIDToUserIDs := account.GetActiveGroupUsers()
|
networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
|
||||||
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
} else {
|
||||||
|
account.InjectProxyPolicies(ctx)
|
||||||
|
resourcePolicies := account.GetResourcePoliciesMap()
|
||||||
|
routers := account.GetResourceRoutersMap()
|
||||||
|
groupIDToUserIDs := account.GetActiveGroupUsers()
|
||||||
|
if c.compactedNetworkMap {
|
||||||
|
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
} else {
|
||||||
|
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
|
||||||
if ok {
|
if ok {
|
||||||
|
|||||||
@@ -12,15 +12,18 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
|
||||||
|
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
|
||||||
|
|
||||||
DnsForwarderPort = nbdns.ForwarderServerPort
|
DnsForwarderPort = nbdns.ForwarderServerPort
|
||||||
OldForwarderPort = nbdns.ForwarderClientPort
|
OldForwarderPort = nbdns.ForwarderClientPort
|
||||||
DnsForwarderPortMinVersion = "v0.59.0"
|
DnsForwarderPortMinVersion = "v0.59.0"
|
||||||
)
|
)
|
||||||
|
|
||||||
type Controller interface {
|
type Controller interface {
|
||||||
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
UpdateAccountPeers(ctx context.Context, accountID string) error
|
||||||
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
|
||||||
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error
|
BufferUpdateAccountPeers(ctx context.Context, accountID string) error
|
||||||
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
|
||||||
GetDNSDomain(settings *types.Settings) string
|
GetDNSDomain(settings *types.Settings) string
|
||||||
StartWarmup(context.Context)
|
StartWarmup(context.Context)
|
||||||
|
|||||||
@@ -44,17 +44,17 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// BufferUpdateAccountPeers mocks base method.
|
// BufferUpdateAccountPeers mocks base method.
|
||||||
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason)
|
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
|
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
|
||||||
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// CountStreams mocks base method.
|
// CountStreams mocks base method.
|
||||||
@@ -238,15 +238,15 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId a
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountPeers mocks base method.
|
// UpdateAccountPeers mocks base method.
|
||||||
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error {
|
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason)
|
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
|
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
|
||||||
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason any) *gomock.Call {
|
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -62,7 +62,7 @@ func (a *MockAccountManager) GetDeletePeerCalls() int {
|
|||||||
return a.deletePeerCalls
|
return a.deletePeerCalls
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
|
func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
|
||||||
a.mu.Lock()
|
a.mu.Lock()
|
||||||
defer a.mu.Unlock()
|
defer a.mu.Unlock()
|
||||||
if a.bufferUpdateCalls == nil {
|
if a.bufferUpdateCalls == nil {
|
||||||
@@ -248,7 +248,7 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
|
mockAM.BufferUpdateAccountPeers(ctx, accountID)
|
||||||
return nil
|
return nil
|
||||||
}).
|
}).
|
||||||
Times(1)
|
Times(1)
|
||||||
|
|||||||
@@ -20,7 +20,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -179,7 +178,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,9 +11,9 @@ import (
|
|||||||
|
|
||||||
// Manager defines the interface for proxy operations
|
// Manager defines the interface for proxy operations
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
|
Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
|
||||||
Disconnect(ctx context.Context, proxyID, sessionID string) error
|
Disconnect(ctx context.Context, proxyID string) error
|
||||||
Heartbeat(ctx context.Context, p *Proxy) error
|
Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||||
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
GetActiveClusters(ctx context.Context) ([]Cluster, error)
|
||||||
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
|
|||||||
@@ -13,8 +13,7 @@ import (
|
|||||||
// store defines the interface for proxy persistence operations
|
// store defines the interface for proxy persistence operations
|
||||||
type store interface {
|
type store interface {
|
||||||
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
SaveProxy(ctx context.Context, p *proxy.Proxy) error
|
||||||
DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
|
UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
|
||||||
UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
|
|
||||||
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
|
||||||
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
|
||||||
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
|
||||||
@@ -44,7 +43,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
|
|||||||
|
|
||||||
// Connect registers a new proxy connection in the database.
|
// Connect registers a new proxy connection in the database.
|
||||||
// capabilities may be nil for old proxies that do not report them.
|
// capabilities may be nil for old proxies that do not report them.
|
||||||
func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
|
func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
|
||||||
now := time.Now()
|
now := time.Now()
|
||||||
var caps proxy.Capabilities
|
var caps proxy.Capabilities
|
||||||
if capabilities != nil {
|
if capabilities != nil {
|
||||||
@@ -52,7 +51,6 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
|
|||||||
}
|
}
|
||||||
p := &proxy.Proxy{
|
p := &proxy.Proxy{
|
||||||
ID: proxyID,
|
ID: proxyID,
|
||||||
SessionID: sessionID,
|
|
||||||
ClusterAddress: clusterAddress,
|
ClusterAddress: clusterAddress,
|
||||||
IPAddress: ipAddress,
|
IPAddress: ipAddress,
|
||||||
LastSeen: now,
|
LastSeen: now,
|
||||||
@@ -63,42 +61,48 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
|
|||||||
|
|
||||||
if err := m.store.SaveProxy(ctx, p); err != nil {
|
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||||
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
|
||||||
"proxyID": proxyID,
|
|
||||||
"sessionID": sessionID,
|
|
||||||
"clusterAddress": clusterAddress,
|
|
||||||
"ipAddress": ipAddress,
|
|
||||||
}).Info("proxy connected")
|
|
||||||
|
|
||||||
return p, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Disconnect marks a proxy as disconnected in the database.
|
|
||||||
func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
|
||||||
if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
|
|
||||||
log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
|
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).WithFields(log.Fields{
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
"proxyID": proxyID,
|
"proxyID": proxyID,
|
||||||
"sessionID": sessionID,
|
"clusterAddress": clusterAddress,
|
||||||
|
"ipAddress": ipAddress,
|
||||||
|
}).Info("proxy connected")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Disconnect marks a proxy as disconnected in the database
|
||||||
|
func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
|
||||||
|
now := time.Now()
|
||||||
|
p := &proxy.Proxy{
|
||||||
|
ID: proxyID,
|
||||||
|
Status: "disconnected",
|
||||||
|
DisconnectedAt: &now,
|
||||||
|
LastSeen: now,
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.store.SaveProxy(ctx, p); err != nil {
|
||||||
|
log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.WithContext(ctx).WithFields(log.Fields{
|
||||||
|
"proxyID": proxyID,
|
||||||
}).Info("proxy disconnected")
|
}).Info("proxy disconnected")
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat updates the proxy's last seen timestamp.
|
// Heartbeat updates the proxy's last seen timestamp
|
||||||
func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
|
func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||||
if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
|
if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
|
||||||
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
|
log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
|
log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
|
||||||
m.metrics.IncrementProxyHeartbeatCount()
|
m.metrics.IncrementProxyHeartbeatCount()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -93,32 +93,31 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Connect mocks base method.
|
// Connect mocks base method.
|
||||||
func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
|
func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||||
ret0, _ := ret[0].(*Proxy)
|
ret0, _ := ret[0].(error)
|
||||||
ret1, _ := ret[1].(error)
|
return ret0
|
||||||
return ret0, ret1
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Connect indicates an expected call of Connect.
|
// Connect indicates an expected call of Connect.
|
||||||
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect mocks base method.
|
// Disconnect mocks base method.
|
||||||
func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
|
func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
|
ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disconnect indicates an expected call of Disconnect.
|
// Disconnect indicates an expected call of Disconnect.
|
||||||
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetActiveClusterAddresses mocks base method.
|
// GetActiveClusterAddresses mocks base method.
|
||||||
@@ -152,17 +151,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat mocks base method.
|
// Heartbeat mocks base method.
|
||||||
func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
|
func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
|
ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
|
||||||
ret0, _ := ret[0].(error)
|
ret0, _ := ret[0].(error)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// Heartbeat indicates an expected call of Heartbeat.
|
// Heartbeat indicates an expected call of Heartbeat.
|
||||||
func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
|
func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
|
||||||
}
|
}
|
||||||
|
|
||||||
// MockController is a mock of Controller interface.
|
// MockController is a mock of Controller interface.
|
||||||
|
|||||||
@@ -18,13 +18,12 @@ type Capabilities struct {
|
|||||||
// Proxy represents a reverse proxy instance
|
// Proxy represents a reverse proxy instance
|
||||||
type Proxy struct {
|
type Proxy struct {
|
||||||
ID string `gorm:"primaryKey;type:varchar(255)"`
|
ID string `gorm:"primaryKey;type:varchar(255)"`
|
||||||
SessionID string `gorm:"type:varchar(36)"`
|
|
||||||
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
ClusterAddress string `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
|
||||||
IPAddress string `gorm:"type:varchar(45)"`
|
IPAddress string `gorm:"type:varchar(45)"`
|
||||||
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
LastSeen time.Time `gorm:"not null;index:idx_proxy_last_seen"`
|
||||||
ConnectedAt *time.Time
|
ConnectedAt *time.Time
|
||||||
DisconnectedAt *time.Time
|
DisconnectedAt *time.Time
|
||||||
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
Status string `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
|
||||||
Capabilities Capabilities `gorm:"embedded"`
|
Capabilities Capabilities `gorm:"embedded"`
|
||||||
CreatedAt time.Time
|
CreatedAt time.Time
|
||||||
UpdatedAt time.Time
|
UpdatedAt time.Time
|
||||||
|
|||||||
@@ -85,7 +85,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
|
|||||||
|
|
||||||
accountMgr := &mock_server.MockAccountManager{
|
accountMgr := &mock_server.MockAccountManager{
|
||||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -25,7 +25,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -232,7 +231,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
|
|||||||
|
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return s, nil
|
return s, nil
|
||||||
}
|
}
|
||||||
@@ -516,7 +515,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return service, nil
|
return service, nil
|
||||||
}
|
}
|
||||||
@@ -820,7 +819,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
|
|||||||
|
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -861,7 +860,7 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
|
|||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -917,7 +916,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string
|
|||||||
|
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1099,7 +1098,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
serviceURL := "https://" + svc.Domain
|
serviceURL := "https://" + svc.Domain
|
||||||
if service.IsL4Protocol(svc.Mode) {
|
if service.IsL4Protocol(svc.Mode) {
|
||||||
@@ -1211,7 +1210,7 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
|
|||||||
|
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||||
|
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -1262,7 +1261,7 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
|
|||||||
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
|
||||||
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
|
m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
|
||||||
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
|
||||||
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
|
m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -447,7 +447,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
|||||||
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
||||||
storedActivity = activityID.(activity.Activity)
|
storedActivity = activityID.(activity.Activity)
|
||||||
},
|
},
|
||||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockStore.EXPECT().
|
mockStore.EXPECT().
|
||||||
@@ -549,7 +549,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
|||||||
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
|
||||||
storedActivity = activityID.(activity.Activity)
|
storedActivity = activityID.(activity.Activity)
|
||||||
},
|
},
|
||||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockStore.EXPECT().
|
mockStore.EXPECT().
|
||||||
@@ -593,7 +593,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
|
|||||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) {
|
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) {
|
||||||
storedMeta = meta
|
storedMeta = meta
|
||||||
},
|
},
|
||||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||||
}
|
}
|
||||||
|
|
||||||
mockStore.EXPECT().
|
mockStore.EXPECT().
|
||||||
@@ -704,7 +704,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
|
|||||||
|
|
||||||
accountMgr := &mock_server.MockAccountManager{
|
accountMgr := &mock_server.MockAccountManager{
|
||||||
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
|
||||||
UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
|
UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
|
||||||
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
|
||||||
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
|
||||||
},
|
},
|
||||||
@@ -1173,7 +1173,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
|
|||||||
mockAcct.EXPECT().
|
mockAcct.EXPECT().
|
||||||
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
|
||||||
mockAcct.EXPECT().
|
mockAcct.EXPECT().
|
||||||
UpdateAccountPeers(ctx, accountID, gomock.Any())
|
UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
|
err = mgr.DeleteService(ctx, accountID, userID, service.ID)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -145,7 +144,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
|
|||||||
|
|
||||||
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
|
||||||
|
|
||||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationUpdate})
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return zone, nil
|
return zone, nil
|
||||||
}
|
}
|
||||||
@@ -207,7 +206,7 @@ func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID
|
|||||||
event()
|
event()
|
||||||
}
|
}
|
||||||
|
|
||||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationDelete})
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
"github.com/netbirdio/netbird/management/server/permissions/modules"
|
||||||
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
"github.com/netbirdio/netbird/management/server/permissions/operations"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/types"
|
|
||||||
"github.com/netbirdio/netbird/shared/management/status"
|
"github.com/netbirdio/netbird/shared/management/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -96,7 +95,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
|
|||||||
meta := record.EventMeta(zone.ID, zone.Name)
|
meta := record.EventMeta(zone.ID, zone.Name)
|
||||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
|
||||||
|
|
||||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationCreate})
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
@@ -155,7 +154,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
|
|||||||
meta := record.EventMeta(zone.ID, zone.Name)
|
meta := record.EventMeta(zone.ID, zone.Name)
|
||||||
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
|
||||||
|
|
||||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationUpdate})
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return record, nil
|
return record, nil
|
||||||
}
|
}
|
||||||
@@ -202,7 +201,7 @@ func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneI
|
|||||||
meta := record.EventMeta(zone.ID, zone.Name)
|
meta := record.EventMeta(zone.ID, zone.Name)
|
||||||
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
|
||||||
|
|
||||||
go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationDelete})
|
go m.accountManager.UpdateAccountPeers(ctx, accountID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -173,7 +173,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
|
|||||||
}
|
}
|
||||||
|
|
||||||
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
|
||||||
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore())
|
srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to create management server: %v", err)
|
log.Fatalf("failed to create management server: %v", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,7 +6,6 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
|
||||||
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
|
||||||
|
|
||||||
@@ -67,12 +66,6 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *BaseServer) SessionStore() *auth.SessionStore {
|
|
||||||
return Create(s, func() *auth.SessionStore {
|
|
||||||
return auth.NewSessionStore(s.CacheStore())
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (s *BaseServer) AuthManager() auth.Manager {
|
func (s *BaseServer) AuthManager() auth.Manager {
|
||||||
audiences := s.Config.GetAuthAudiences()
|
audiences := s.Config.GetAuthAudiences()
|
||||||
audience := s.Config.HttpConfig.AuthAudience
|
audience := s.Config.HttpConfig.AuthAudience
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user