Compare commits


2 Commits

Author SHA1 Message Date
Zoltán Papp
faf3559ea7 Revert "[client] Prefer systemd-resolved stub over file mode regardless of resolv.conf header (#5935)"
This reverts commit 75e408f51c.
2026-04-23 21:29:52 +02:00
Zoltán Papp
dc8c2edf50 Revert "[client] Add TTL-based refresh to mgmt DNS cache via handler chain (#5945)"
This reverts commit 801de8c68d.
2026-04-23 21:29:46 +02:00
188 changed files with 7430 additions and 11410 deletions

View File

@@ -9,7 +9,7 @@ on:
   pull_request:
 env:
-  SIGN_PIPE_VER: "v0.1.4"
+  SIGN_PIPE_VER: "v0.1.2"
   GORELEASER_VER: "v2.14.3"
   PRODUCT_NAME: "NetBird"
   COPYRIGHT: "NetBird GmbH"
@@ -114,13 +114,7 @@ jobs:
           retention-days: 30
   release:
-    runs-on: ubuntu-24.04-8-core
-    outputs:
-      release_artifact_url: ${{ steps.upload_release.outputs.artifact-url }}
-      linux_packages_artifact_url: ${{ steps.upload_linux_packages.outputs.artifact-url }}
-      windows_packages_artifact_url: ${{ steps.upload_windows_packages.outputs.artifact-url }}
-      macos_packages_artifact_url: ${{ steps.upload_macos_packages.outputs.artifact-url }}
-      ghcr_images: ${{ steps.tag_and_push_images.outputs.images_markdown }}
+    runs-on: ubuntu-latest-m
     env:
       flags: ""
     steps:
@@ -219,13 +213,10 @@ jobs:
         if: always()
        run: rm -f /tmp/gpg-rpm-signing-key.asc
       - name: Tag and push images (amd64 only)
-        id: tag_and_push_images
         if: |
          (github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
          (github.event_name == 'push' && github.ref == 'refs/heads/main')
         run: |
-          set -euo pipefail
          resolve_tags() {
            if [[ "${{ github.event_name }}" == "pull_request" ]]; then
              echo "pr-${{ github.event.pull_request.number }}"
@@ -234,17 +225,6 @@ jobs:
            fi
          }
-          ghcr_package_url() {
-            local image="$1" package encoded_package
-            package="${image#ghcr.io/}"
-            package="${package#*/}"
-            package="${package%%:*}"
-            encoded_package="${package//\//%2F}"
-            echo "https://github.com/orgs/netbirdio/packages/container/package/${encoded_package}"
-          }
-          image_refs=()
          tag_and_push() {
            local src="$1" img_name tag dst
            img_name="${src%%:*}"
@@ -253,56 +233,35 @@ jobs:
              echo "Tagging ${src} -> ${dst}"
              docker tag "$src" "$dst"
              docker push "$dst"
-              image_refs+=("$dst")
            done
          }
-          cat > /tmp/goreleaser-artifacts.json <<'JSON'
-          ${{ steps.goreleaser.outputs.artifacts }}
-          JSON
-          mapfile -t src_images < <(
-            jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name | select(startswith("ghcr.io/"))' /tmp/goreleaser-artifacts.json
-          )
-          for src in "${src_images[@]}"; do
-            tag_and_push "$src"
-          done
-          {
-            echo "images_markdown<<EOF"
-            if [[ ${#image_refs[@]} -eq 0 ]]; then
-              echo "_No GHCR images were pushed._"
-            else
-              printf '%s\n' "${image_refs[@]}" | sort -u | while read -r image; do
-                printf -- '- [`%s`](%s)\n' "$image" "$(ghcr_package_url "$image")"
-              done
-            fi
-            echo "EOF"
-          } >> "$GITHUB_OUTPUT"
+          export -f tag_and_push resolve_tags
+          echo '${{ steps.goreleaser.outputs.artifacts }}' | \
+            jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
+            grep '^ghcr.io/' | while read -r SRC; do
+              tag_and_push "$SRC"
+            done
       - name: upload non tags for debug purposes
-        id: upload_release
         uses: actions/upload-artifact@v4
         with:
           name: release
           path: dist/
           retention-days: 7
       - name: upload linux packages
-        id: upload_linux_packages
         uses: actions/upload-artifact@v4
         with:
           name: linux-packages
           path: dist/netbird_linux**
           retention-days: 7
       - name: upload windows packages
-        id: upload_windows_packages
         uses: actions/upload-artifact@v4
         with:
           name: windows-packages
           path: dist/netbird_windows**
           retention-days: 7
       - name: upload macos packages
-        id: upload_macos_packages
         uses: actions/upload-artifact@v4
         with:
           name: macos-packages
@@ -311,8 +270,6 @@ jobs:
   release_ui:
     runs-on: ubuntu-latest
-    outputs:
-      release_ui_artifact_url: ${{ steps.upload_release_ui.outputs.artifact-url }}
     steps:
       - name: Parse semver string
         id: semver_parser
@@ -403,7 +360,6 @@ jobs:
         if: always()
         run: rm -f /tmp/gpg-rpm-signing-key.asc
       - name: upload non tags for debug purposes
-        id: upload_release_ui
         uses: actions/upload-artifact@v4
         with:
           name: release-ui
@@ -412,8 +368,6 @@ jobs:
   release_ui_darwin:
     runs-on: macos-latest
-    outputs:
-      release_ui_darwin_artifact_url: ${{ steps.upload_release_ui_darwin.outputs.artifact-url }}
     steps:
       - if: ${{ !startsWith(github.ref, 'refs/tags/v') }}
         run: echo "flags=--snapshot" >> $GITHUB_ENV
@@ -448,258 +402,15 @@ jobs:
         env:
           GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
       - name: upload non tags for debug purposes
-        id: upload_release_ui_darwin
         uses: actions/upload-artifact@v4
         with:
           name: release-ui-darwin
           path: dist/
           retention-days: 3
test_windows_installer:
name: "Windows Installer / Build Test"
runs-on: windows-2022
needs: [release, release_ui]
strategy:
fail-fast: false
matrix:
include:
- arch: amd64
wintun_arch: amd64
- arch: arm64
wintun_arch: arm64
defaults:
run:
shell: powershell
env:
PackageWorkdir: netbird_windows_${{ matrix.arch }}
downloadPath: '${{ github.workspace }}\temp'
steps:
- name: Parse semver string
id: semver_parser
uses: booxmedialtd/ws-action-parse-semver@v1
with:
input_string: ${{ (startsWith(github.ref, 'refs/tags/v') && github.ref) || 'refs/tags/v0.0.0' }}
version_extractor_regex: '\/v(.*)$'
- name: Checkout
uses: actions/checkout@v4
- name: Add 7-Zip to PATH
run: echo "C:\Program Files\7-Zip" | Out-File -FilePath $env:GITHUB_PATH -Encoding utf8 -Append
- name: Download release artifacts
uses: actions/download-artifact@v4
with:
name: release
path: release
- name: Download UI release artifacts
uses: actions/download-artifact@v4
with:
name: release-ui
path: release-ui
- name: Stage binaries into dist
run: |
$workdir = "dist\${{ env.PackageWorkdir }}"
New-Item -ItemType Directory -Force -Path $workdir | Out-Null
$client = Get-ChildItem -Recurse -Path release -Filter "netbird_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
$ui = Get-ChildItem -Recurse -Path release-ui -Filter "netbird-ui-windows_*_windows_${{ matrix.arch }}.tar.gz" | Select-Object -First 1
if (-not $client) { Write-Host "::error::client tarball not found for ${{ matrix.arch }}"; exit 1 }
if (-not $ui) { Write-Host "::error::ui tarball not found for ${{ matrix.arch }}"; exit 1 }
Write-Host "Client: $($client.FullName)"
Write-Host "UI: $($ui.FullName)"
tar -zvxf $client.FullName -C $workdir
tar -zvxf $ui.FullName -C $workdir
Get-ChildItem $workdir
- name: Download wintun
uses: carlosperate/download-file-action@v2
id: download-wintun
with:
file-url: https://pkgs.netbird.io/wintun/wintun-0.14.1.zip
file-name: wintun.zip
location: ${{ env.downloadPath }}
sha256: '07c256185d6ee3652e09fa55c0b673e2624b565e02c4b9091c79ca7d2f24ef51'
- name: Decompress wintun files
run: tar -zvxf "${{ steps.download-wintun.outputs.file-path }}" -C ${{ env.downloadPath }}
- name: Move wintun.dll into dist
run: mv ${{ env.downloadPath }}\wintun\bin\${{ matrix.wintun_arch }}\wintun.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download Mesa3D (amd64 only)
uses: carlosperate/download-file-action@v2
id: download-mesa3d
if: matrix.arch == 'amd64'
with:
file-url: https://downloads.fdossena.com/Projects/Mesa3D/Builds/MesaForWindows-x64-20.1.8.7z
file-name: mesa3d.7z
location: ${{ env.downloadPath }}
sha256: '71c7cb64ec229a1d6b8d62fa08e1889ed2bd17c0eeede8689daf0f25cb31d6b9'
- name: Extract Mesa3D driver (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ env.downloadPath }}" "${{ env.downloadPath }}/mesa3d.7z"
- name: Move opengl32.dll into dist (amd64 only)
if: matrix.arch == 'amd64'
run: mv ${{ env.downloadPath }}\opengl32.dll ${{ github.workspace }}\dist\${{ env.PackageWorkdir }}\
- name: Download EnVar plugin for NSIS
uses: carlosperate/download-file-action@v2
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/7/7f/EnVar_plugin.zip
file-name: envar_plugin.zip
location: ${{ github.workspace }}
- name: Extract EnVar plugin
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/envar_plugin.zip"
- name: Download ShellExecAsUser plugin for NSIS (amd64 only)
uses: carlosperate/download-file-action@v2
if: matrix.arch == 'amd64'
with:
file-url: https://nsis.sourceforge.io/mediawiki/images/6/68/ShellExecAsUser_amd64-Unicode.7z
file-name: ShellExecAsUser_amd64-Unicode.7z
location: ${{ github.workspace }}
- name: Extract ShellExecAsUser plugin (amd64 only)
if: matrix.arch == 'amd64'
run: 7z x -o"${{ github.workspace }}/NSIS_Plugins" "${{ github.workspace }}/ShellExecAsUser_amd64-Unicode.7z"
- name: Build NSIS installer
uses: joncloud/makensis-action@v3.3
with:
additional-plugin-paths: ${{ github.workspace }}/NSIS_Plugins/Plugins
script-file: client/installer.nsis
arguments: "/V4 /DARCH=${{ matrix.arch }}"
env:
APPVER: ${{ steps.semver_parser.outputs.major }}.${{ steps.semver_parser.outputs.minor }}.${{ steps.semver_parser.outputs.patch }}.${{ github.run_id }}
- name: Rename NSIS installer
run: mv netbird-installer.exe netbird_installer_test_windows_${{ matrix.arch }}.exe
- name: Install WiX
run: |
dotnet tool install --global wix --version 6.0.2
wix extension add WixToolset.Util.wixext/6.0.2
- name: Build MSI installer
env:
NETBIRD_VERSION: "${{ steps.semver_parser.outputs.fullversion }}"
run: wix build -arch ${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -ext WixToolset.Util.wixext -o netbird_installer_test_windows_${{ matrix.arch }}.msi .\client\netbird.wxs -d ProcessorArchitecture=${{ matrix.arch == 'amd64' && 'x64' || 'arm64' }} -d ArchSuffix=${{ matrix.arch }}
- name: Upload installer artifacts
if: always()
uses: actions/upload-artifact@v4
with:
name: windows-installer-test-${{ matrix.arch }}
path: |
netbird_installer_test_windows_${{ matrix.arch }}.exe
netbird_installer_test_windows_${{ matrix.arch }}.msi
retention-days: 3
comment_release_artifacts:
name: Comment release artifacts
runs-on: ubuntu-latest
needs: [release, release_ui, release_ui_darwin]
if: ${{ always() && github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository }}
permissions:
contents: read
issues: write
pull-requests: write
steps:
- name: Create or update PR comment
uses: actions/github-script@v7
env:
RELEASE_RESULT: ${{ needs.release.result }}
RELEASE_UI_RESULT: ${{ needs.release_ui.result }}
RELEASE_UI_DARWIN_RESULT: ${{ needs.release_ui_darwin.result }}
RELEASE_ARTIFACT_URL: ${{ needs.release.outputs.release_artifact_url }}
LINUX_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.linux_packages_artifact_url }}
WINDOWS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.windows_packages_artifact_url }}
MACOS_PACKAGES_ARTIFACT_URL: ${{ needs.release.outputs.macos_packages_artifact_url }}
RELEASE_UI_ARTIFACT_URL: ${{ needs.release_ui.outputs.release_ui_artifact_url }}
RELEASE_UI_DARWIN_ARTIFACT_URL: ${{ needs.release_ui_darwin.outputs.release_ui_darwin_artifact_url }}
GHCR_IMAGES_MARKDOWN: ${{ needs.release.outputs.ghcr_images }}
with:
github-token: ${{ secrets.GITHUB_TOKEN }}
script: |
const marker = '<!-- netbird-release-artifacts -->';
const { owner, repo } = context.repo;
const issue_number = context.payload.pull_request.number;
const runUrl = `${context.serverUrl}/${owner}/${repo}/actions/runs/${context.runId}`;
const shortSha = context.payload.pull_request.head.sha.slice(0, 7);
const artifactCell = (url, result) => {
if (url) return `[Download](${url})`;
return result && result !== 'success' ? `_Not available (${result})_` : '_Not available_';
};
const artifacts = [
['All release artifacts', process.env.RELEASE_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Linux packages', process.env.LINUX_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['Windows packages', process.env.WINDOWS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['macOS packages', process.env.MACOS_PACKAGES_ARTIFACT_URL, process.env.RELEASE_RESULT],
['UI artifacts', process.env.RELEASE_UI_ARTIFACT_URL, process.env.RELEASE_UI_RESULT],
['UI macOS artifacts', process.env.RELEASE_UI_DARWIN_ARTIFACT_URL, process.env.RELEASE_UI_DARWIN_RESULT],
];
const artifactRows = artifacts
.map(([name, url, result]) => `| ${name} | ${artifactCell(url, result)} |`)
.join('\n');
const ghcrImages = (process.env.GHCR_IMAGES_MARKDOWN || '').trim() || '_No GHCR images were pushed._';
const body = [
marker,
'## Release artifacts',
'',
`Built for PR head \`${shortSha}\` in [workflow run #${process.env.GITHUB_RUN_NUMBER}](${runUrl}).`,
'',
'| Artifact | Link |',
'| --- | --- |',
artifactRows,
'',
'### GHCR images (amd64)',
ghcrImages,
'',
'_This comment is updated by the Release workflow. Artifact links expire according to the workflow retention policy._',
].join('\n');
const comments = await github.paginate(github.rest.issues.listComments, {
owner,
repo,
issue_number,
per_page: 100,
});
const previous = comments.find(comment =>
comment.user?.type === 'Bot' && comment.body?.includes(marker)
);
if (previous) {
await github.rest.issues.updateComment({
owner,
repo,
comment_id: previous.id,
body,
});
core.info(`Updated release artifacts comment ${previous.id}`);
} else {
const { data } = await github.rest.issues.createComment({
owner,
repo,
issue_number,
body,
});
core.info(`Created release artifacts comment ${data.id}`);
}
   trigger_signer:
     runs-on: ubuntu-latest
-    needs: [release, release_ui, release_ui_darwin, test_windows_installer]
+    needs: [release, release_ui, release_ui_darwin]
     if: startsWith(github.ref, 'refs/tags/')
     steps:
       - name: Trigger binaries sign pipelines
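A note on the reverted GHCR step: the removed `ghcr_package_url` helper derives a GitHub Packages page URL from an image reference by stripping the registry host, dropping the owner segment, cutting the tag, and percent-encoding any remaining slashes. A minimal Go sketch of the same derivation (hypothetical helper; the `netbirdio` org is hard-coded, as in the removed shell):

```go
package main

import (
	"fmt"
	"strings"
)

// ghcrPackageURL mirrors the removed shell helper: ghcr.io/<owner>/<pkg>:<tag>
// becomes the package page URL, with slashes in nested package paths encoded
// as %2F so the whole package name stays a single URL path segment.
func ghcrPackageURL(image string) string {
	pkg := strings.TrimPrefix(image, "ghcr.io/")
	if i := strings.IndexByte(pkg, '/'); i >= 0 {
		pkg = pkg[i+1:] // drop the owner/org segment
	}
	if i := strings.IndexByte(pkg, ':'); i >= 0 {
		pkg = pkg[:i] // drop the tag
	}
	pkg = strings.ReplaceAll(pkg, "/", "%2F")
	return "https://github.com/orgs/netbirdio/packages/container/package/" + pkg
}

func main() {
	fmt.Println(ghcrPackageURL("ghcr.io/netbirdio/netbird:pr-1234"))
	// https://github.com/orgs/netbirdio/packages/container/package/netbird
}
```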

View File

@@ -9,8 +9,6 @@ concurrency:
   group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
   cancel-in-progress: true
-# Receiving workflows (cloud sync-tag, mobile bump-netbird) expect the short
-# tag form (e.g. v0.30.0), not refs/tags/v0.30.0 — github.ref_name, not github.ref.
 jobs:
   trigger_sync_tag:
     runs-on: ubuntu-latest
@@ -22,30 +20,4 @@ jobs:
           ref: main
           repo: ${{ secrets.UPSTREAM_REPO }}
           token: ${{ secrets.NC_GITHUB_TOKEN }}
           inputs: '{ "tag": "${{ github.ref_name }}" }'
-  trigger_android_bump:
-    runs-on: ubuntu-latest
-    if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
-    steps:
-      - name: Trigger android-client submodule bump
-        uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
-        with:
-          workflow: bump-netbird.yml
-          ref: main
-          repo: netbirdio/android-client
-          token: ${{ secrets.NC_GITHUB_TOKEN }}
-          inputs: '{ "tag": "${{ github.ref_name }}" }'
-  trigger_ios_bump:
-    runs-on: ubuntu-latest
-    if: github.event.created && !github.event.deleted && startsWith(github.ref, 'refs/tags/v') && !contains(github.ref_name, '-')
-    steps:
-      - name: Trigger ios-client submodule bump
-        uses: benc-uk/workflow-dispatch@7a027648b88c2413826b6ddd6c76114894dc5ec4 # v1.3.1
-        with:
-          workflow: bump-netbird.yml
-          ref: main
-          repo: netbirdio/ios-client
-          token: ${{ secrets.NC_GITHUB_TOKEN }}
-          inputs: '{ "tag": "${{ github.ref_name }}" }'

View File

@@ -58,11 +58,6 @@ linters:
     govet:
       enable:
         - nilness
-      disable:
-        # The inline analyzer flags x/exp/maps Clone/Clear with //go:fix inline
-        # directives but cannot perform the rewrite due to generic type
-        # parameter inference limitations in the Go inliner.
-        - inline
       enable-all: false
     revive:
       rules:

View File

@@ -17,7 +17,6 @@ ENV \
     NETBIRD_BIN="/usr/local/bin/netbird" \
     NB_LOG_FILE="console,/var/log/netbird/client.log" \
     NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
-    NB_ENABLE_CAPTURE="false" \
     NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
 ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -23,7 +23,6 @@ ENV \
     NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
     NB_LOG_FILE="console,/var/lib/netbird/client.log" \
     NB_DISABLE_DNS="true" \
-    NB_ENABLE_CAPTURE="false" \
     NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
 ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]

View File

@@ -1,196 +0,0 @@
package cmd
import (
"context"
"fmt"
"io"
"os"
"os/signal"
"path/filepath"
"strings"
"syscall"
"github.com/hashicorp/go-multierror"
"github.com/spf13/cobra"
"google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
nberrors "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util/capture"
)
var captureCmd = &cobra.Command{
Use: "capture",
Short: "Capture packets on the WireGuard interface",
Long: `Captures decrypted packets flowing through the WireGuard interface.
Default output is human-readable text. Use --pcap or --output for pcap binary.
Requires --enable-capture to be set at service install or reconfigure time.
Examples:
netbird debug capture
netbird debug capture host 100.64.0.1 and port 443
netbird debug capture tcp
netbird debug capture icmp
netbird debug capture src host 10.0.0.1 and dst port 80
netbird debug capture -o capture.pcap
netbird debug capture --pcap | tshark -r -
netbird debug capture --pcap | tcpdump -r - -n`,
Args: cobra.ArbitraryArgs,
RunE: runCapture,
}
func init() {
debugCmd.AddCommand(captureCmd)
captureCmd.Flags().Bool("pcap", false, "Force pcap binary output (default when --output is set)")
captureCmd.Flags().BoolP("verbose", "v", false, "Show seq/ack, TTL, window, total length")
captureCmd.Flags().Bool("ascii", false, "Print payload as ASCII after each packet (useful for HTTP)")
captureCmd.Flags().Uint32("snap-len", 0, "Max bytes per packet (0 = full)")
captureCmd.Flags().DurationP("duration", "d", 0, "Capture duration (0 = until interrupted)")
captureCmd.Flags().StringP("output", "o", "", "Write pcap to file instead of stdout")
}
func runCapture(cmd *cobra.Command, args []string) error {
conn, err := getClient(cmd)
if err != nil {
return err
}
defer func() {
if err := conn.Close(); err != nil {
cmd.PrintErrf(errCloseConnection, err)
}
}()
client := proto.NewDaemonServiceClient(conn)
req, err := buildCaptureRequest(cmd, args)
if err != nil {
return err
}
ctx, cancel := signal.NotifyContext(cmd.Context(), syscall.SIGINT, syscall.SIGTERM)
defer cancel()
stream, err := client.StartCapture(ctx, req)
if err != nil {
return handleCaptureError(err)
}
// First Recv is the empty acceptance message from the server. If the
// device is unavailable (kernel WG, not connected, capture disabled),
// the server returns an error instead.
if _, err := stream.Recv(); err != nil {
return handleCaptureError(err)
}
out, cleanup, err := captureOutput(cmd)
if err != nil {
return err
}
if req.TextOutput {
cmd.PrintErrf("Capturing packets... Press Ctrl+C to stop.\n")
} else {
cmd.PrintErrf("Capturing packets (pcap)... Press Ctrl+C to stop.\n")
}
streamErr := streamCapture(ctx, cmd, stream, out)
cleanupErr := cleanup()
if streamErr != nil {
return streamErr
}
return cleanupErr
}
func buildCaptureRequest(cmd *cobra.Command, args []string) (*proto.StartCaptureRequest, error) {
req := &proto.StartCaptureRequest{}
if len(args) > 0 {
expr := strings.Join(args, " ")
if _, err := capture.ParseFilter(expr); err != nil {
return nil, fmt.Errorf("invalid filter: %w", err)
}
req.FilterExpr = expr
}
if snap, _ := cmd.Flags().GetUint32("snap-len"); snap > 0 {
req.SnapLen = snap
}
if d, _ := cmd.Flags().GetDuration("duration"); d != 0 {
if d < 0 {
return nil, fmt.Errorf("duration must not be negative")
}
req.Duration = durationpb.New(d)
}
req.Verbose, _ = cmd.Flags().GetBool("verbose")
req.Ascii, _ = cmd.Flags().GetBool("ascii")
outPath, _ := cmd.Flags().GetString("output")
forcePcap, _ := cmd.Flags().GetBool("pcap")
req.TextOutput = !forcePcap && outPath == ""
return req, nil
}
func streamCapture(ctx context.Context, cmd *cobra.Command, stream proto.DaemonService_StartCaptureClient, out io.Writer) error {
for {
pkt, err := stream.Recv()
if err != nil {
if ctx.Err() != nil {
cmd.PrintErrf("\nCapture stopped.\n")
return nil //nolint:nilerr // user interrupted
}
if err == io.EOF {
cmd.PrintErrf("\nCapture finished.\n")
return nil
}
return handleCaptureError(err)
}
if _, err := out.Write(pkt.GetData()); err != nil {
return fmt.Errorf("write output: %w", err)
}
}
}
// captureOutput returns the writer for capture data and a cleanup function
// that finalizes the file. Errors from the cleanup must be propagated.
func captureOutput(cmd *cobra.Command) (io.Writer, func() error, error) {
outPath, _ := cmd.Flags().GetString("output")
if outPath == "" {
return os.Stdout, func() error { return nil }, nil
}
f, err := os.CreateTemp(filepath.Dir(outPath), filepath.Base(outPath)+".*.tmp")
if err != nil {
return nil, nil, fmt.Errorf("create output file: %w", err)
}
tmpPath := f.Name()
return f, func() error {
var merr *multierror.Error
if err := f.Close(); err != nil {
merr = multierror.Append(merr, fmt.Errorf("close output file: %w", err))
}
fi, statErr := os.Stat(tmpPath)
if statErr != nil || fi.Size() == 0 {
if rmErr := os.Remove(tmpPath); rmErr != nil && !os.IsNotExist(rmErr) {
merr = multierror.Append(merr, fmt.Errorf("remove empty output file: %w", rmErr))
}
return nberrors.FormatErrorOrNil(merr)
}
if err := os.Rename(tmpPath, outPath); err != nil {
merr = multierror.Append(merr, fmt.Errorf("rename output file: %w", err))
return nberrors.FormatErrorOrNil(merr)
}
cmd.PrintErrf("Wrote %s\n", outPath)
return nberrors.FormatErrorOrNil(merr)
}, nil
}
func handleCaptureError(err error) error {
if s, ok := status.FromError(err); ok {
return fmt.Errorf("%s", s.Message())
}
return err
}
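The removed `captureOutput` finalizes `--output` files by writing to a temp file in the destination directory and renaming it into place, so an interrupted or empty capture never clobbers the requested path. A self-contained sketch of that pattern (hypothetical `writeAtomically` helper, not part of the codebase):

```go
package main

import (
	"fmt"
	"os"
	"path/filepath"
)

// writeAtomically stages data in a temp file beside the destination and
// renames it into place, mirroring the pattern captureOutput used above.
// Rename is atomic on POSIX filesystems when both paths share a volume,
// which is why the temp file is created in the destination's directory.
func writeAtomically(path string, data []byte) error {
	f, err := os.CreateTemp(filepath.Dir(path), filepath.Base(path)+".*.tmp")
	if err != nil {
		return fmt.Errorf("create temp file: %w", err)
	}
	tmp := f.Name()
	if _, err := f.Write(data); err != nil {
		f.Close()
		os.Remove(tmp)
		return fmt.Errorf("write: %w", err)
	}
	if err := f.Close(); err != nil {
		os.Remove(tmp)
		return fmt.Errorf("close: %w", err)
	}
	return os.Rename(tmp, path)
}

func main() {
	if err := writeAtomically("capture.pcap", []byte("example")); err != nil {
		fmt.Fprintln(os.Stderr, err)
	}
}
```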

View File

@@ -9,7 +9,6 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/spf13/cobra" "github.com/spf13/cobra"
"google.golang.org/grpc/status" "google.golang.org/grpc/status"
"google.golang.org/protobuf/types/known/durationpb"
"github.com/netbirdio/netbird/client/internal" "github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/internal/debug" "github.com/netbirdio/netbird/client/internal/debug"
@@ -240,50 +239,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
 		}()
 	}

-	captureStarted := false
-	if wantCapture, _ := cmd.Flags().GetBool("capture"); wantCapture {
-		captureTimeout := duration + 30*time.Second
-		const maxBundleCapture = 10 * time.Minute
-		if captureTimeout > maxBundleCapture {
-			captureTimeout = maxBundleCapture
-		}
-		_, err := client.StartBundleCapture(cmd.Context(), &proto.StartBundleCaptureRequest{
-			Timeout: durationpb.New(captureTimeout),
-		})
-		if err != nil {
-			cmd.PrintErrf("Failed to start packet capture: %v\n", status.Convert(err).Message())
-		} else {
-			captureStarted = true
-			cmd.Println("Packet capture started.")
-			// Safety: always stop on exit, even if the normal stop below runs too.
-			defer func() {
-				if captureStarted {
-					stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-					defer cancel()
-					if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
-						cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
-					}
-				}
-			}()
-		}
-	}

 	if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
 		return waitErr
 	}

 	cmd.Println("\nDuration completed")

-	if captureStarted {
-		stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-		defer cancel()
-		if _, err := client.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
-			cmd.PrintErrf("Failed to stop packet capture: %v\n", err)
-		} else {
-			captureStarted = false
-			cmd.Println("Packet capture stopped.")
-		}
-	}

 	if cpuProfilingStarted {
 		if _, err := client.StopCPUProfile(cmd.Context(), &proto.StopCPUProfileRequest{}); err != nil {
 			cmd.PrintErrf("Failed to stop CPU profiling: %v\n", err)
@@ -456,5 +416,4 @@ func init() {
 	forCmd.Flags().BoolVarP(&systemInfoFlag, "system-info", "S", true, "Adds system information to the debug bundle")
 	forCmd.Flags().BoolVarP(&uploadBundleFlag, "upload-bundle", "U", false, "Uploads the debug bundle to a server")
 	forCmd.Flags().StringVar(&uploadBundleURLFlag, "upload-bundle-url", types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
-	forCmd.Flags().Bool("capture", false, "Capture packets during the debug duration and include in bundle")
 }

View File

@@ -10,7 +10,6 @@ import (
 	log "github.com/sirupsen/logrus"
 	"github.com/spf13/cobra"
-	"golang.org/x/term"
 	"google.golang.org/grpc/codes"
 	gstatus "google.golang.org/grpc/status"
@@ -24,7 +23,6 @@ import (
 func init() {
 	loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
-	loginCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
 	loginCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
 	loginCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) Netbird config file location")
 }
@@ -258,7 +256,7 @@ func doForegroundLogin(ctx context.Context, cmd *cobra.Command, setupKey string,
 }

 func handleSSOLogin(ctx context.Context, cmd *cobra.Command, loginResp *proto.LoginResponse, client proto.DaemonServiceClient, pm *profilemanager.ProfileManager) error {
-	openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser, showQR)
+	openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)

 	resp, err := client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
 	if err != nil {
@@ -326,7 +324,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
 		return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
 	}

-	openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser, showQR)
+	openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)

 	tokenInfo, err := oAuthFlow.WaitToken(context.TODO(), flowInfo)
 	if err != nil {
@@ -336,7 +334,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *pro
 	return &tokenInfo, nil
 }

-func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser, showQR bool) {
+func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
 	var codeMsg string
 	if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
 		codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
@@ -350,12 +348,6 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
 			verificationURIComplete + " " + codeMsg)
 	}

-	if showQR {
-		if f, ok := cmd.OutOrStdout().(*os.File); ok && term.IsTerminal(int(f.Fd())) {
-			printQRCode(f, verificationURIComplete)
-		}
-	}

 	cmd.Println("")

 	if !noBrowser {
View File

@@ -1,25 +0,0 @@
package cmd
import (
"io"
"github.com/mdp/qrterminal/v3"
)
// printQRCode prints a QR code for the given URL to the writer.
// Called only when the user explicitly requests QR output via --qr.
func printQRCode(w io.Writer, url string) {
if url == "" {
return
}
qrterminal.GenerateWithConfig(url, qrterminal.Config{
Level: qrterminal.M,
Writer: w,
HalfBlocks: true,
BlackChar: qrterminal.BLACK_BLACK,
WhiteChar: qrterminal.WHITE_WHITE,
BlackWhiteChar: qrterminal.BLACK_WHITE,
WhiteBlackChar: qrterminal.WHITE_BLACK,
QuietZone: qrterminal.QUIET_ZONE,
})
}

View File

@@ -1,26 +0,0 @@
package cmd
import (
"bytes"
"testing"
)
func TestPrintQRCode_EmptyURL(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "")
if buf.Len() != 0 {
t.Error("expected no output for empty URL")
}
}
func TestPrintQRCode_WritesOutput(t *testing.T) {
var buf bytes.Buffer
printQRCode(&buf, "https://example.com/auth")
if buf.Len() == 0 {
t.Error("expected QR code output for non-empty URL")
}
}

View File

@@ -75,7 +75,6 @@ var (
 	mtu                    uint16
 	profilesDisabled       bool
 	updateSettingsDisabled bool
-	captureEnabled         bool
 	networksDisabled       bool

 	rootCmd = &cobra.Command{

View File

@@ -44,7 +44,6 @@ func init() {
 	serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
 	serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
 	serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
-	serviceCmd.PersistentFlags().BoolVar(&captureEnabled, "enable-capture", false, "Enables packet capture via 'netbird debug capture'. To persist, use: netbird service install --enable-capture")
 	serviceCmd.PersistentFlags().BoolVar(&networksDisabled, "disable-networks", false, "Disables network selection. If enabled, the client will not allow listing, selecting, or deselecting networks. To persist, use: netbird service install --disable-networks")
 	rootCmd.PersistentFlags().StringVarP(&serviceName, "service", "s", defaultServiceName, "Netbird system service name")

View File

@@ -61,7 +61,7 @@ func (p *program) Start(svc service.Service) error {
 		}
 	}

-	serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, captureEnabled, networksDisabled)
+	serverInstance := server.New(p.ctx, util.FindFirstLogPath(logFiles), configPath, profilesDisabled, updateSettingsDisabled, networksDisabled)
 	if err := serverInstance.Start(); err != nil {
 		log.Fatalf("failed to start daemon: %v", err)
 	}

View File

@@ -59,10 +59,6 @@ func buildServiceArguments() []string {
 		args = append(args, "--disable-update-settings")
 	}

-	if captureEnabled {
-		args = append(args, "--enable-capture")
-	}

 	if networksDisabled {
 		args = append(args, "--disable-networks")
 	}

View File

@@ -28,7 +28,6 @@ type serviceParams struct {
 	LogFiles              []string          `json:"log_files,omitempty"`
 	DisableProfiles       bool              `json:"disable_profiles,omitempty"`
 	DisableUpdateSettings bool              `json:"disable_update_settings,omitempty"`
-	EnableCapture         bool              `json:"enable_capture,omitempty"`
 	DisableNetworks       bool              `json:"disable_networks,omitempty"`
 	ServiceEnvVars        map[string]string `json:"service_env_vars,omitempty"`
 }
@@ -80,7 +79,6 @@ func currentServiceParams() *serviceParams {
 		LogFiles:              logFiles,
 		DisableProfiles:       profilesDisabled,
 		DisableUpdateSettings: updateSettingsDisabled,
-		EnableCapture:         captureEnabled,
 		DisableNetworks:       networksDisabled,
 	}
@@ -146,10 +144,6 @@ func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
 		updateSettingsDisabled = params.DisableUpdateSettings
 	}

-	if !serviceCmd.PersistentFlags().Changed("enable-capture") {
-		captureEnabled = params.EnableCapture
-	}

 	if !serviceCmd.PersistentFlags().Changed("disable-networks") {
 		networksDisabled = params.DisableNetworks
 	}
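The deleted `enable-capture` branch followed this file's precedence rule: a persisted service param is applied only when the corresponding flag was not set explicitly on the command line, which `Flags().Changed` distinguishes from a flag left at its default. A minimal sketch of that rule with spf13/pflag (hypothetical flag set, not the NetBird CLI itself):

```go
package main

import (
	"fmt"

	"github.com/spf13/pflag"
)

// A persisted param acts as a default: it is applied only when the user did
// not set the flag explicitly, mirroring the applyServiceParams pattern above.
func main() {
	fs := pflag.NewFlagSet("service", pflag.ContinueOnError)
	disableNetworks := fs.Bool("disable-networks", false, "disable network selection")
	_ = fs.Parse([]string{}) // no explicit flag on this run

	persisted := true // value loaded from the saved service params
	if !fs.Changed("disable-networks") {
		*disableNetworks = persisted
	}
	fmt.Println(*disableNetworks) // true: persisted value wins when the flag is unset
}
```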

View File

@@ -535,7 +535,6 @@ func fieldToGlobalVar(field string) string {
 		"LogFiles":              "logFiles",
 		"DisableProfiles":       "profilesDisabled",
 		"DisableUpdateSettings": "updateSettingsDisabled",
-		"EnableCapture":         "captureEnabled",
 		"DisableNetworks":       "networksDisabled",
 		"ServiceEnvVars":        "serviceEnvVars",
 	}

View File

@@ -135,7 +135,7 @@ func startManagement(t *testing.T, config *config.Config, testFile string) (*grp
 	if err != nil {
 		t.Fatal(err)
 	}
-	mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil, nil)
+	mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &mgmt.MockIntegratedValidator{}, networkMapController, nil)
 	if err != nil {
 		t.Fatal(err)
 	}
@@ -160,7 +160,7 @@ func startClientDaemon(
 	s := grpc.NewServer()
 	server := client.New(ctx,
-		"", "", false, false, false, false)
+		"", "", false, false, false)
 	if err := server.Start(); err != nil {
 		t.Fatal(err)
 	}

View File

@@ -39,9 +39,6 @@ const (
 	noBrowserFlag = "no-browser"
 	noBrowserDesc = "do not open the browser for SSO login"

-	showQRFlag = "qr"
-	showQRDesc = "show QR code for the SSO login URL (useful for headless machines without browser access)"

 	profileNameFlag = "profile"
 	profileNameDesc = "profile name to use for the login. If not specified, the last used profile will be used."
 )
@@ -51,7 +48,6 @@
 	dnsLabels          []string
 	dnsLabelsValidated domain.List
 	noBrowser          bool
-	showQR             bool
 	profileName        string
 	configPath         string
@@ -84,7 +80,6 @@ func init() {
 	)
 	upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
-	upCmd.PersistentFlags().BoolVar(&showQR, showQRFlag, false, showQRDesc)
 	upCmd.PersistentFlags().StringVar(&profileName, profileNameFlag, "", profileNameDesc)
 	upCmd.PersistentFlags().StringVarP(&configPath, "config", "c", "", "(DEPRECATED) NetBird config file location. ")

View File

@@ -1,65 +0,0 @@
package embed
import (
"io"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/util/capture"
)
// CaptureOptions configures a packet capture session.
type CaptureOptions struct {
// Output receives pcap-formatted data. Nil disables pcap output.
Output io.Writer
// TextOutput receives human-readable packet summaries. Nil disables text output.
TextOutput io.Writer
// Filter is a BPF-like filter expression (e.g. "host 10.0.0.1 and tcp port 443").
// Empty captures all packets.
Filter string
// Verbose adds seq/ack, TTL, window, and total length to text output.
Verbose bool
// ASCII dumps transport payload as printable ASCII after each packet line.
ASCII bool
}
// CaptureStats reports capture session counters.
type CaptureStats struct {
Packets int64
Bytes int64
Dropped int64
}
// CaptureSession represents an active packet capture. Call Stop to end the
// capture and flush buffered packets.
type CaptureSession struct {
sess *capture.Session
engine *internal.Engine
}
// Stop ends the capture, flushes remaining packets, and detaches from the device.
// Safe to call multiple times.
func (cs *CaptureSession) Stop() {
if cs.engine != nil {
_ = cs.engine.SetCapture(nil)
cs.engine = nil
}
if cs.sess != nil {
cs.sess.Stop()
}
}
// Stats returns current capture counters.
func (cs *CaptureSession) Stats() CaptureStats {
s := cs.sess.Stats()
return CaptureStats{
Packets: s.Packets,
Bytes: s.Bytes,
Dropped: s.Dropped,
}
}
// Done returns a channel that is closed when the capture's writer goroutine
// has fully exited and all buffered packets have been flushed.
func (cs *CaptureSession) Done() <-chan struct{} {
return cs.sess.Done()
}

View File

@@ -24,7 +24,6 @@ import (
 	"github.com/netbirdio/netbird/client/system"
 	"github.com/netbirdio/netbird/shared/management/domain"
 	mgmProto "github.com/netbirdio/netbird/shared/management/proto"
-	"github.com/netbirdio/netbird/util/capture"
 )

 var (
@@ -66,7 +65,7 @@ type Options struct {
 	PrivateKey string
 	// ManagementURL overrides the default management server URL
 	ManagementURL string
-	// PreSharedKey is the pre-shared key for the tunnel interface
+	// PreSharedKey is the pre-shared key for the WireGuard interface
 	PreSharedKey string
 	// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
 	LogOutput io.Writer
@@ -82,9 +81,9 @@ type Options struct {
 	DisableClientRoutes bool
 	// BlockInbound blocks all inbound connections from peers
 	BlockInbound bool
-	// WireguardPort is the port for the tunnel interface. Use 0 for a random port.
+	// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
 	WireguardPort *int
-	// MTU is the MTU for the tunnel interface.
+	// MTU is the MTU for the WireGuard interface.
 	// Valid values are in the range 576..8192 bytes.
 	// If non-nil, this value overrides any value stored in the config file.
 	// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
@@ -470,52 +469,6 @@ func (c *Client) VerifySSHHostKey(peerAddress string, key []byte) error {
 	return sshcommon.VerifyHostKey(storedKey, key, peerAddress)
 }

-// StartCapture begins capturing packets on this client's tunnel device.
-// Only one capture can be active at a time; starting a new one stops the previous.
-// Call StopCapture (or CaptureSession.Stop) to end it.
-func (c *Client) StartCapture(opts CaptureOptions) (*CaptureSession, error) {
-	engine, err := c.getEngine()
-	if err != nil {
-		return nil, err
-	}
-
-	var matcher capture.Matcher
-	if opts.Filter != "" {
-		m, err := capture.ParseFilter(opts.Filter)
-		if err != nil {
-			return nil, fmt.Errorf("parse filter: %w", err)
-		}
-		matcher = m
-	}
-
-	sess, err := capture.NewSession(capture.Options{
-		Output: opts.Output,
-		TextOutput: opts.TextOutput,
-		Matcher: matcher,
-		Verbose: opts.Verbose,
-		ASCII: opts.ASCII,
-	})
-	if err != nil {
-		return nil, fmt.Errorf("create capture session: %w", err)
-	}
-
-	if err := engine.SetCapture(sess); err != nil {
-		sess.Stop()
-		return nil, fmt.Errorf("set capture: %w", err)
-	}
-
-	return &CaptureSession{sess: sess, engine: engine}, nil
-}
-
-// StopCapture stops the active capture session if one is running.
-func (c *Client) StopCapture() error {
-	engine, err := c.getEngine()
-	if err != nil {
-		return err
-	}
-	return engine.SetCapture(nil)
-}

 // getEngine safely retrieves the engine from the client with proper locking.
 // Returns ErrClientNotStarted if the client is not started.
 // Returns ErrEngineNotStarted if the engine is not available.
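For reference, the embed capture API removed above was driven roughly as in this sketch; it compiles only against the pre-revert tree, and client construction is elided:

```go
// Sketch of the removed embed capture API (pre-revert signatures).
package sketch

import (
	"fmt"
	"os"
	"time"

	"github.com/netbirdio/netbird/client/embed"
)

// capturePcap records 30 seconds of traffic from an already-started embedded
// client into a pcap file, using the StartCapture/CaptureSession pair that
// this revert deletes.
func capturePcap(client *embed.Client, path string) error {
	f, err := os.Create(path)
	if err != nil {
		return err
	}
	defer f.Close()

	sess, err := client.StartCapture(embed.CaptureOptions{
		Output: f,                                // pcap-formatted stream
		Filter: "host 10.0.0.1 and tcp port 443", // BPF-like expression; empty captures all
	})
	if err != nil {
		return err
	}
	defer sess.Stop()

	time.Sleep(30 * time.Second)
	stats := sess.Stats()
	fmt.Printf("captured %d packets, %d bytes, %d dropped\n", stats.Packets, stats.Bytes, stats.Dropped)
	return nil
}
```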

View File

@@ -115,13 +115,12 @@
 	localipmanager *localIPManager

 	udpTracker *conntrack.UDPTracker
 	icmpTracker *conntrack.ICMPTracker
 	tcpTracker *conntrack.TCPTracker
 	forwarder atomic.Pointer[forwarder.Forwarder]
-	pendingCapture atomic.Pointer[forwarder.PacketCapture]
 	logger *nblog.Logger
 	flowLogger nftypes.FlowLogger

 	blockRule firewall.Rule
@@ -352,19 +351,6 @@ func (m *Manager) determineRouting() error {
 	return nil
 }

-// SetPacketCapture sets or clears packet capture on the forwarder endpoint.
-// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
-func (m *Manager) SetPacketCapture(pc forwarder.PacketCapture) {
-	if pc == nil {
-		m.pendingCapture.Store(nil)
-	} else {
-		m.pendingCapture.Store(&pc)
-	}
-	if fwder := m.forwarder.Load(); fwder != nil {
-		fwder.SetCapture(pc)
-	}
-}

 // initForwarder initializes the forwarder, it disables routing on errors
 func (m *Manager) initForwarder() error {
 	if m.forwarder.Load() != nil {
@@ -386,11 +372,6 @@
 	m.forwarder.Store(forwarder)

-	// Re-load after store: a concurrent SetPacketCapture may have seen forwarder as nil and only updated pendingCapture.
-	if pc := m.pendingCapture.Load(); pc != nil {
-		forwarder.SetCapture(*pc)
-	}

 	log.Debug("forwarder initialized")
 	return nil
@@ -633,7 +614,6 @@
 	}

 	if fwder := m.forwarder.Load(); fwder != nil {
-		fwder.SetCapture(nil)
 		fwder.Stop()
 	}

View File

@@ -12,19 +12,12 @@
 	nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
 )

-// PacketCapture captures raw packets for debugging. Implementations must be
-// safe for concurrent use and must not block.
-type PacketCapture interface {
-	Offer(data []byte, outbound bool)
-}

 // endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
 type endpoint struct {
 	logger *nblog.Logger
 	dispatcher stack.NetworkDispatcher
 	device *wgdevice.Device
 	mtu atomic.Uint32
-	capture atomic.Pointer[PacketCapture]
 }

 func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
@@ -61,17 +54,13 @@
 			continue
 		}

-		pktBytes := data.AsSlice()
+		// Send the packet through WireGuard
 		address := netHeader.DestinationAddress()
-		if err := e.device.CreateOutboundPacket(pktBytes, address.AsSlice()); err != nil {
+		err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
+		if err != nil {
 			e.logger.Error1("CreateOutboundPacket: %v", err)
 			continue
 		}
-		if pc := e.capture.Load(); pc != nil {
-			(*pc).Offer(pktBytes, true)
-		}
 		written++
 	}

View File

@@ -139,16 +139,6 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
 	return f, nil
 }

-// SetCapture sets or clears the packet capture on the forwarder endpoint.
-// This captures outbound packets that bypass the FilteredDevice (netstack forwarding).
-func (f *Forwarder) SetCapture(pc PacketCapture) {
-	if pc == nil {
-		f.endpoint.capture.Store(nil)
-		return
-	}
-	f.endpoint.capture.Store(&pc)
-}

 func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
 	if len(payload) < header.IPv4MinimumSize {
 		return fmt.Errorf("packet too small: %d bytes", len(payload))

View File

@@ -270,9 +270,5 @@ func (f *Forwarder) injectICMPReply(id stack.TransportEndpointID, icmpPayload []
 		return 0
 	}

-	if pc := f.endpoint.capture.Load(); pc != nil {
-		(*pc).Offer(fullPacket, true)
-	}

 	return len(fullPacket)
 }

View File

@@ -3,7 +3,6 @@ package device
 import (
 	"net/netip"
 	"sync"
-	"sync/atomic"

 	"golang.zx2c4.com/wireguard/tun"
 )
@@ -29,20 +28,11 @@
 	SetTCPPacketHook(ip netip.Addr, dPort uint16, hook func(packet []byte) bool)
 }

-// PacketCapture captures raw packets for debugging. Implementations must be
-// safe for concurrent use and must not block.
-type PacketCapture interface {
-	// Offer submits a packet for capture. outbound is true for packets
-	// leaving the host (Read path), false for packets arriving (Write path).
-	Offer(data []byte, outbound bool)
-}

 // FilteredDevice to override Read or Write of packets
 type FilteredDevice struct {
 	tun.Device
 	filter PacketFilter
-	capture atomic.Pointer[PacketCapture]
 	mutex sync.RWMutex
 	closeOnce sync.Once
 }
@@ -73,25 +63,20 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
 	if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
 		return 0, err
 	}

 	d.mutex.RLock()
 	filter := d.filter
 	d.mutex.RUnlock()

-	if filter != nil {
-		for i := 0; i < n; i++ {
-			if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
-				bufs = append(bufs[:i], bufs[i+1:]...)
-				sizes = append(sizes[:i], sizes[i+1:]...)
-				n--
-				i--
-			}
-		}
-	}
-
-	if pc := d.capture.Load(); pc != nil {
-		for i := 0; i < n; i++ {
-			(*pc).Offer(bufs[i][offset:offset+sizes[i]], true)
-		}
-	}
+	if filter == nil {
+		return
+	}
+
+	for i := 0; i < n; i++ {
+		if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
+			bufs = append(bufs[:i], bufs[i+1:]...)
+			sizes = append(sizes[:i], sizes[i+1:]...)
+			n--
+			i--
+		}
+	}
@@ -100,13 +85,6 @@
 // Write wraps write method with filtering feature
 func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
-	// Capture before filtering so dropped packets are still visible in captures.
-	if pc := d.capture.Load(); pc != nil {
-		for _, buf := range bufs {
-			(*pc).Offer(buf[offset:], false)
-		}
-	}
 	d.mutex.RLock()
 	filter := d.filter
 	d.mutex.RUnlock()
@@ -118,10 +96,9 @@
 	filteredBufs := make([][]byte, 0, len(bufs))
 	dropped := 0
 	for _, buf := range bufs {
-		if filter.FilterInbound(buf[offset:], len(buf)) {
-			dropped++
-		} else {
+		if !filter.FilterInbound(buf[offset:], len(buf)) {
 			filteredBufs = append(filteredBufs, buf)
+			dropped++
 		}
 	}
@@ -136,14 +113,3 @@
 	d.filter = filter
 	d.mutex.Unlock()
 }

-// SetCapture sets or clears the packet capture sink. Pass nil to disable.
-// Uses atomic store so the hot path (Read/Write) is a single pointer load
-// with no locking overhead when capture is off.
-func (d *FilteredDevice) SetCapture(pc PacketCapture) {
-	if pc == nil {
-		d.capture.Store(nil)
-		return
-	}
-	d.capture.Store(&pc)
-}

View File

@@ -158,7 +158,7 @@ func TestDeviceWrapperRead(t *testing.T) {
 		t.Errorf("unexpected error: %v", err)
 		return
 	}
-	if n != 1 {
+	if n != 0 {
 		t.Errorf("expected n=1, got %d", n)
 		return
 	}

View File

@@ -201,18 +201,7 @@ Pop $0
 Function .onInit
 StrCpy $INSTDIR "${INSTALL_DIR}"
-; Default autostart to enabled so silent installs (/S) match the interactive default
-StrCpy $AutostartEnabled "1"
-; Pre-0.70.1 installers ran without SetRegView, so their uninstall keys live
-; in the 32-bit view. Fall back to it so upgrades still find them.
-SetRegView 64
 ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
-${If} $R0 == ""
-SetRegView 32
-ReadRegStr $R0 HKLM "Software\Microsoft\Windows\CurrentVersion\Uninstall\$(^NAME)" "UninstallString"
-SetRegView 64
-${EndIf}
 ${If} $R0 != ""
 # if silent install jump to uninstall step
 IfSilent uninstall
@@ -225,10 +214,6 @@
 ${EndIf}
 FunctionEnd

-Function un.onInit
-SetRegView 64
-FunctionEnd

 ######################################################################
 Section -MainProgram
 ${INSTALL_TYPE}
@@ -243,7 +228,6 @@ Section -MainProgram
!else !else
File /r "..\\dist\\netbird_windows_amd64\\" File /r "..\\dist\\netbird_windows_amd64\\"
!endif !endif
File "..\\client\\ui\\assets\\netbird.png"
SectionEnd SectionEnd
###################################################################### ######################################################################
@@ -263,11 +247,9 @@ WriteRegStr ${REG_ROOT} "${UI_REG_APP_PATH}" "" "$INSTDIR\${UI_APP_EXE}"
 ; Create autostart registry entry based on checkbox
 DetailPrint "Autostart enabled: $AutostartEnabled"
 ${If} $AutostartEnabled == "1"
-WriteRegStr HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}" '"$INSTDIR\${UI_APP_EXE}.exe"'
+WriteRegStr HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}" "$INSTDIR\${UI_APP_EXE}.exe"
 DetailPrint "Added autostart registry entry: $INSTDIR\${UI_APP_EXE}.exe"
 ${Else}
-DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
-; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
 DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
 DetailPrint "Autostart not enabled by user"
 ${EndIf}
@@ -301,8 +283,6 @@ ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
 ; Remove autostart registry entry
 DetailPrint "Removing autostart registry entry if exists..."
-DeleteRegValue HKLM "${AUTOSTART_REG_KEY}" "${APP_NAME}"
-; Legacy: pre-HKLM installs wrote to HKCU; clean that up too.
 DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"

 ; Handle data deletion based on checkbox
@@ -341,7 +321,6 @@ DetailPrint "Removing registry keys..."
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}" DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}" DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
DeleteRegKey HKCU "Software\Classes\AppUserModelId\${APP_NAME}"
DetailPrint "Removing application directory from PATH..." DetailPrint "Removing application directory from PATH..."
EnVar::SetHKLM EnVar::SetHKLM

View File

@@ -333,10 +333,6 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
c.statusRecorder.MarkSignalConnected() c.statusRecorder.MarkSignalConnected()
relayURLs, token := parseRelayInfo(loginResp) relayURLs, token := parseRelayInfo(loginResp)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
relayURLs = override
}
peerConfig := loginResp.GetPeerConfig() peerConfig := loginResp.GetPeerConfig()
engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath) engineConfig, err := createEngineConfig(myPrivateKey, c.config, peerConfig, logPath)

View File

@@ -61,7 +61,6 @@ allocs.prof: Allocations profiling information.
threadcreate.prof: Thread creation profiling information. threadcreate.prof: Thread creation profiling information.
cpu.prof: CPU profiling information. cpu.prof: CPU profiling information.
stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation. stack_trace.txt: Complete stack traces of all goroutines at the time of bundle creation.
capture.pcap: Packet capture in pcap format. Only present when capture was running during bundle collection. Omitted from anonymized bundles because it contains raw decrypted packet data.
Anonymization Process Anonymization Process
@@ -235,7 +234,6 @@ type BundleGenerator struct {
logPath string logPath string
tempDir string tempDir string
cpuProfile []byte cpuProfile []byte
capturePath string
refreshStatus func() // Optional callback to refresh status before bundle generation refreshStatus func() // Optional callback to refresh status before bundle generation
clientMetrics MetricsExporter clientMetrics MetricsExporter
@@ -259,8 +257,7 @@ type GeneratorDependencies struct {
LogPath string LogPath string
TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used. TempDir string // Directory for temporary bundle zip files. If empty, os.TempDir() is used.
CPUProfile []byte CPUProfile []byte
CapturePath string RefreshStatus func() // Optional callback to refresh status before bundle generation
RefreshStatus func()
ClientMetrics MetricsExporter ClientMetrics MetricsExporter
} }
@@ -280,7 +277,6 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
logPath: deps.LogPath, logPath: deps.LogPath,
tempDir: deps.TempDir, tempDir: deps.TempDir,
cpuProfile: deps.CPUProfile, cpuProfile: deps.CPUProfile,
capturePath: deps.CapturePath,
refreshStatus: deps.RefreshStatus, refreshStatus: deps.RefreshStatus,
clientMetrics: deps.ClientMetrics, clientMetrics: deps.ClientMetrics,
@@ -350,10 +346,6 @@ func (g *BundleGenerator) createArchive() error {
log.Errorf("failed to add CPU profile to debug bundle: %v", err) log.Errorf("failed to add CPU profile to debug bundle: %v", err)
} }
if err := g.addCaptureFile(); err != nil {
log.Errorf("failed to add capture file to debug bundle: %v", err)
}
if err := g.addStackTrace(); err != nil { if err := g.addStackTrace(); err != nil {
log.Errorf("failed to add stack trace to debug bundle: %v", err) log.Errorf("failed to add stack trace to debug bundle: %v", err)
} }
@@ -607,12 +599,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
if g.internalConfig.EnableSSHRemotePortForwarding != nil { if g.internalConfig.EnableSSHRemotePortForwarding != nil {
configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding)) configContent.WriteString(fmt.Sprintf("EnableSSHRemotePortForwarding: %v\n", *g.internalConfig.EnableSSHRemotePortForwarding))
} }
if g.internalConfig.DisableSSHAuth != nil {
configContent.WriteString(fmt.Sprintf("DisableSSHAuth: %v\n", *g.internalConfig.DisableSSHAuth))
}
if g.internalConfig.SSHJWTCacheTTL != nil {
configContent.WriteString(fmt.Sprintf("SSHJWTCacheTTL: %d\n", *g.internalConfig.SSHJWTCacheTTL))
}
configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes)) configContent.WriteString(fmt.Sprintf("DisableClientRoutes: %v\n", g.internalConfig.DisableClientRoutes))
configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes)) configContent.WriteString(fmt.Sprintf("DisableServerRoutes: %v\n", g.internalConfig.DisableServerRoutes))
@@ -639,7 +625,6 @@ func (g *BundleGenerator) addCommonConfigFields(configContent *strings.Builder)
} }
configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled)) configContent.WriteString(fmt.Sprintf("LazyConnectionEnabled: %v\n", g.internalConfig.LazyConnectionEnabled))
configContent.WriteString(fmt.Sprintf("MTU: %d\n", g.internalConfig.MTU))
} }
func (g *BundleGenerator) addProf() (err error) { func (g *BundleGenerator) addProf() (err error) {
@@ -684,29 +669,6 @@ func (g *BundleGenerator) addCPUProfile() error {
return nil return nil
} }
func (g *BundleGenerator) addCaptureFile() error {
if g.capturePath == "" {
return nil
}
if g.anonymize {
log.Info("skipping capture file in anonymized bundle (contains raw packet data)")
return nil
}
f, err := os.Open(g.capturePath)
if err != nil {
return fmt.Errorf("open capture file: %w", err)
}
defer f.Close()
if err := g.addFileToZip(f, "capture.pcap"); err != nil {
return fmt.Errorf("add capture file to zip: %w", err)
}
return nil
}
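
addCaptureFile above adds the pcap only when a capture path is set and anonymization is off, since raw decrypted packets cannot be anonymized. A minimal runnable sketch of that gating using only archive/zip (the helper and entry name are illustrative):

package main

import (
	"archive/zip"
	"fmt"
	"io"
	"os"
)

// addCapture mirrors the gate: an empty path means capture never ran, and an
// anonymized bundle must not carry raw decrypted packets.
func addCapture(zw *zip.Writer, capturePath string, anonymize bool) error {
	if capturePath == "" || anonymize {
		return nil
	}
	f, err := os.Open(capturePath)
	if err != nil {
		return fmt.Errorf("open capture file: %w", err)
	}
	defer f.Close()
	w, err := zw.Create("capture.pcap")
	if err != nil {
		return fmt.Errorf("create zip entry: %w", err)
	}
	if _, err := io.Copy(w, f); err != nil {
		return fmt.Errorf("copy capture: %w", err)
	}
	return nil
}

func main() {
	out, err := os.CreateTemp("", "bundle-*.zip")
	if err != nil {
		panic(err)
	}
	defer os.Remove(out.Name())
	zw := zip.NewWriter(out)
	if err := addCapture(zw, "", false); err != nil { // empty path: no-op
		fmt.Println("error:", err)
	}
	if err := zw.Close(); err != nil {
		panic(err)
	}
	fmt.Println("bundle written to", out.Name())
}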
func (g *BundleGenerator) addStackTrace() error { func (g *BundleGenerator) addStackTrace() error {
buf := make([]byte, 5242880) // 5 MB buffer buf := make([]byte, 5242880) // 5 MB buffer
n := runtime.Stack(buf, true) n := runtime.Stack(buf, true)

View File

@@ -5,21 +5,16 @@ import (
"bytes" "bytes"
"encoding/json" "encoding/json"
"net" "net"
"net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect"
"strings" "strings"
"testing" "testing"
"time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require" "github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/anonymize" "github.com/netbirdio/netbird/client/anonymize"
"github.com/netbirdio/netbird/client/configs" "github.com/netbirdio/netbird/client/configs"
"github.com/netbirdio/netbird/client/internal/profilemanager"
"github.com/netbirdio/netbird/shared/management/domain"
mgmProto "github.com/netbirdio/netbird/shared/management/proto" mgmProto "github.com/netbirdio/netbird/shared/management/proto"
) )
@@ -476,8 +471,8 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"HOME": "/root", "HOME": "/root",
"PATH": "/usr/bin", "PATH": "/usr/bin",
"NB_LOG_LEVEL": "debug", "NB_LOG_LEVEL": "debug",
}, },
}, },
@@ -494,9 +489,9 @@ func TestSanitizeServiceEnvVars(t *testing.T) {
anonymize: false, anonymize: false,
input: map[string]any{ input: map[string]any{
jsonKeyServiceEnv: map[string]any{ jsonKeyServiceEnv: map[string]any{
"NB_SETUP_KEY": "abc123", "NB_SETUP_KEY": "abc123",
"NB_API_TOKEN": "tok_xyz", "NB_API_TOKEN": "tok_xyz",
"NB_LOG_LEVEL": "info", "NB_LOG_LEVEL": "info",
}, },
}, },
check: func(t *testing.T, params map[string]any) { check: func(t *testing.T, params map[string]any) {
@@ -771,127 +766,3 @@ Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
assert.Contains(t, anonNftables, "chain input {") assert.Contains(t, anonNftables, "chain input {")
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;") assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
} }
// TestAddConfig_AllFieldsCovered uses reflection to ensure every field in
// profilemanager.Config is either rendered in the debug bundle or explicitly
// excluded. When a new field is added to Config, this test fails until the
// developer either dumps it in addConfig/addCommonConfigFields or adds it to
// the excluded set with a justification.
func TestAddConfig_AllFieldsCovered(t *testing.T) {
excluded := map[string]string{
"PrivateKey": "sensitive: WireGuard private key",
"PreSharedKey": "sensitive: WireGuard pre-shared key",
"SSHKey": "sensitive: SSH private key",
"ClientCertKeyPair": "non-config: parsed cert pair, not serialized",
}
mURL, _ := url.Parse("https://api.example.com:443")
aURL, _ := url.Parse("https://admin.example.com:443")
bTrue := true
iVal := 42
cfg := &profilemanager.Config{
PrivateKey: "priv",
PreSharedKey: "psk",
ManagementURL: mURL,
AdminURL: aURL,
WgIface: "wt0",
WgPort: 51820,
NetworkMonitor: &bTrue,
IFaceBlackList: []string{"eth0"},
DisableIPv6Discovery: true,
RosenpassEnabled: true,
RosenpassPermissive: true,
ServerSSHAllowed: &bTrue,
EnableSSHRoot: &bTrue,
EnableSSHSFTP: &bTrue,
EnableSSHLocalPortForwarding: &bTrue,
EnableSSHRemotePortForwarding: &bTrue,
DisableSSHAuth: &bTrue,
SSHJWTCacheTTL: &iVal,
DisableClientRoutes: true,
DisableServerRoutes: true,
DisableDNS: true,
DisableFirewall: true,
BlockLANAccess: true,
BlockInbound: true,
DisableNotifications: &bTrue,
DNSLabels: domain.List{},
SSHKey: "sshkey",
NATExternalIPs: []string{"1.2.3.4"},
CustomDNSAddress: "1.1.1.1:53",
DisableAutoConnect: true,
DNSRouteInterval: 5 * time.Second,
ClientCertPath: "/tmp/cert",
ClientCertKeyPath: "/tmp/key",
LazyConnectionEnabled: true,
MTU: 1280,
}
for _, anonymize := range []bool{false, true} {
t.Run("anonymize="+map[bool]string{true: "true", false: "false"}[anonymize], func(t *testing.T) {
g := &BundleGenerator{
anonymizer: newAnonymizerForTest(),
internalConfig: cfg,
anonymize: anonymize,
}
var sb strings.Builder
g.addCommonConfigFields(&sb)
rendered := sb.String() + renderAddConfigSpecific(g)
val := reflect.ValueOf(cfg).Elem()
typ := val.Type()
var missing []string
for i := 0; i < typ.NumField(); i++ {
name := typ.Field(i).Name
if _, ok := excluded[name]; ok {
continue
}
if !strings.Contains(rendered, name+":") {
missing = append(missing, name)
}
}
if len(missing) > 0 {
t.Fatalf("Config field(s) not present in debug bundle output: %v\n"+
"Either render the field in addCommonConfigFields/addConfig, "+
"or add it to the excluded map with a justification.", missing)
}
})
}
}
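
The reflection guard above generalizes beyond this config struct. A miniature, runnable version with a hypothetical three-field Config shows the shape: every exported field must either appear as "Name:" in the rendered output or carry a justified exclusion:

package main

import (
	"fmt"
	"reflect"
	"strings"
)

type Config struct {
	WgPort     int
	PrivateKey string // sensitive, never rendered
	MTU        int
}

func main() {
	rendered := "WgPort: 51820\nMTU: 1280\n"
	excluded := map[string]string{"PrivateKey": "sensitive: WireGuard private key"}

	typ := reflect.TypeOf(Config{})
	var missing []string
	for i := 0; i < typ.NumField(); i++ {
		name := typ.Field(i).Name
		if _, ok := excluded[name]; ok {
			continue
		}
		if !strings.Contains(rendered, name+":") {
			missing = append(missing, name)
		}
	}
	fmt.Println("missing fields:", missing) // empty: every field accounted for
}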
// renderAddConfigSpecific renders the fields handled by the anonymize/non-anonymize
// branches in addConfig (ManagementURL, AdminURL, NATExternalIPs, CustomDNSAddress).
// addCommonConfigFields covers the rest. Keeping this in the test mirrors the
// production shape without needing to write an actual zip.
func renderAddConfigSpecific(g *BundleGenerator) string {
var sb strings.Builder
if g.anonymize {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.ManagementURL.String()) + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.anonymizer.AnonymizeURI(g.internalConfig.AdminURL.String()) + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.anonymizer.AnonymizeString(g.internalConfig.CustomDNSAddress) + "\n")
}
} else {
if g.internalConfig.ManagementURL != nil {
sb.WriteString("ManagementURL: " + g.internalConfig.ManagementURL.String() + "\n")
}
if g.internalConfig.AdminURL != nil {
sb.WriteString("AdminURL: " + g.internalConfig.AdminURL.String() + "\n")
}
sb.WriteString("NATExternalIPs: x\n")
if g.internalConfig.CustomDNSAddress != "" {
sb.WriteString("CustomDNSAddress: " + g.internalConfig.CustomDNSAddress + "\n")
}
}
return sb.String()
}
func newAnonymizerForTest() *anonymize.Anonymizer {
return anonymize.NewAnonymizer(anonymize.DefaultAddresses())
}

View File

@@ -13,7 +13,6 @@ import (
const ( const (
defaultResolvConfPath = "/etc/resolv.conf" defaultResolvConfPath = "/etc/resolv.conf"
nsswitchConfPath = "/etc/nsswitch.conf"
) )
type resolvConf struct { type resolvConf struct {

View File

@@ -1,10 +1,7 @@
package dns package dns
import ( import (
"context"
"fmt" "fmt"
"math"
"net"
"slices" "slices"
"strconv" "strconv"
"strings" "strings"
@@ -195,12 +192,6 @@ func (c *HandlerChain) logHandlers() {
} }
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
c.dispatch(w, r, math.MaxInt)
}
// dispatch routes a DNS request through the chain, skipping handlers with
// priority > maxPriority. Shared by ServeDNS and ResolveInternal.
func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority int) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
return return
} }
@@ -225,9 +216,6 @@ func (c *HandlerChain) dispatch(w dns.ResponseWriter, r *dns.Msg, maxPriority in
// Try handlers in priority order // Try handlers in priority order
for _, entry := range handlers { for _, entry := range handlers {
if entry.Priority > maxPriority {
continue
}
if !c.isHandlerMatch(qname, entry) { if !c.isHandlerMatch(qname, entry) {
continue continue
} }
@@ -285,55 +273,6 @@ func (c *HandlerChain) logResponse(logger *log.Entry, cw *ResponseWriterChain, q
cw.response.Len(), meta, time.Since(startTime)) cw.response.Len(), meta, time.Since(startTime))
} }
// ResolveInternal runs an in-process DNS query against the chain, skipping any
// handler with priority > maxPriority. Used by internal callers (e.g. the mgmt
// cache refresher) that must bypass themselves to avoid loops. Honors ctx
// cancellation; on ctx.Done the dispatch goroutine is left to drain on its own
// (bounded by the invoked handler's internal timeout).
func (c *HandlerChain) ResolveInternal(ctx context.Context, r *dns.Msg, maxPriority int) (*dns.Msg, error) {
if len(r.Question) == 0 {
return nil, fmt.Errorf("empty question")
}
base := &internalResponseWriter{}
done := make(chan struct{})
go func() {
c.dispatch(base, r, maxPriority)
close(done)
}()
select {
case <-done:
case <-ctx.Done():
// Prefer a completed response if dispatch finished concurrently with cancellation.
select {
case <-done:
default:
return nil, fmt.Errorf("resolve %s: %w", strings.ToLower(r.Question[0].Name), ctx.Err())
}
}
if base.response == nil || base.response.Rcode == dns.RcodeRefused {
return nil, fmt.Errorf("no handler resolved %s at priority ≤ %d",
strings.ToLower(r.Question[0].Name), maxPriority)
}
return base.response, nil
}
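
The cancellation handling in ResolveInternal reduces to a double select: on ctx.Done, take a completed result if one raced in; otherwise return the context error and let the worker drain on its own. A standalone sketch with a stubbed work function (all names hypothetical):

package main

import (
	"context"
	"errors"
	"fmt"
	"time"
)

func resolve(ctx context.Context, work func() string) (string, error) {
	done := make(chan struct{})
	var result string
	go func() {
		result = work()
		close(done)
	}()
	select {
	case <-done:
	case <-ctx.Done():
		select {
		case <-done: // work finished concurrently with cancellation: keep it
		default:
			return "", ctx.Err() // worker goroutine drains on its own
		}
	}
	return result, nil
}

func main() {
	ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
	defer cancel()
	_, err := resolve(ctx, func() string { time.Sleep(time.Second); return "late" })
	fmt.Println(errors.Is(err, context.DeadlineExceeded)) // true
}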
// HasRootHandlerAtOrBelow reports whether any "." handler is registered at
// priority ≤ maxPriority.
func (c *HandlerChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
c.mu.RLock()
defer c.mu.RUnlock()
for _, h := range c.handlers {
if h.Pattern == "." && h.Priority <= maxPriority {
return true
}
}
return false
}
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool { func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
switch { switch {
case entry.Pattern == ".": case entry.Pattern == ".":
@@ -352,36 +291,3 @@ func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
} }
} }
} }
// internalResponseWriter captures a dns.Msg for in-process chain queries.
type internalResponseWriter struct {
response *dns.Msg
}
func (w *internalResponseWriter) WriteMsg(m *dns.Msg) error { w.response = m; return nil }
func (w *internalResponseWriter) LocalAddr() net.Addr { return nil }
func (w *internalResponseWriter) RemoteAddr() net.Addr { return nil }
// Write unpacks raw DNS bytes so handlers that call Write instead of WriteMsg
// still surface their answer to ResolveInternal.
func (w *internalResponseWriter) Write(p []byte) (int, error) {
msg := new(dns.Msg)
if err := msg.Unpack(p); err != nil {
return 0, err
}
w.response = msg
return len(p), nil
}
func (w *internalResponseWriter) Close() error { return nil }
func (w *internalResponseWriter) TsigStatus() error { return nil }
// TsigTimersOnly is part of dns.ResponseWriter.
func (w *internalResponseWriter) TsigTimersOnly(bool) {
// no-op: in-process queries carry no TSIG state.
}
// Hijack is part of dns.ResponseWriter.
func (w *internalResponseWriter) Hijack() {
// no-op: in-process queries have no underlying connection to hand off.
}
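
The Write path above leans on miekg/dns wire (de)serialization. A tiny round-trip demo of Pack/Unpack, the same calls a raw-writing handler and internalResponseWriter make (record values are illustrative):

package main

import (
	"fmt"

	"github.com/miekg/dns"
)

func main() {
	q := new(dns.Msg)
	q.SetQuestion("example.com.", dns.TypeA)

	resp := new(dns.Msg)
	resp.SetReply(q)
	rr, err := dns.NewRR("example.com. 60 IN A 10.0.0.3")
	if err != nil {
		panic(err)
	}
	resp.Answer = []dns.RR{rr}

	packed, err := resp.Pack() // what a raw-writing handler hands to Write
	if err != nil {
		panic(err)
	}
	got := new(dns.Msg)
	if err := got.Unpack(packed); err != nil { // what internalResponseWriter does
		panic(err)
	}
	fmt.Println(got.Answer[0]) // the A record survives the round trip
}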

View File

@@ -1,15 +1,11 @@
package dns_test package dns_test
import ( import (
"context"
"net"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/mock" "github.com/stretchr/testify/mock"
"github.com/stretchr/testify/require"
nbdns "github.com/netbirdio/netbird/client/internal/dns" nbdns "github.com/netbirdio/netbird/client/internal/dns"
"github.com/netbirdio/netbird/client/internal/dns/test" "github.com/netbirdio/netbird/client/internal/dns/test"
@@ -1046,163 +1042,3 @@ func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
}) })
} }
} }
// answeringHandler writes a fixed A record to ack the query. Used to verify
// which handler ResolveInternal dispatches to.
type answeringHandler struct {
name string
ip string
}
func (h *answeringHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
_ = w.WriteMsg(resp)
}
func (h *answeringHandler) String() string { return h.name }
func TestHandlerChain_ResolveInternal_SkipsAboveMaxPriority(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
low := &answeringHandler{name: "low", ip: "10.0.0.2"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
chain.AddHandler("example.com.", low, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
assert.NotNil(t, resp)
assert.Equal(t, 1, len(resp.Answer))
a, ok := resp.Answer[0].(*dns.A)
assert.True(t, ok)
assert.Equal(t, "10.0.0.2", a.A.String(), "should skip mgmtCache handler and resolve via upstream")
}
func TestHandlerChain_ResolveInternal_ErrorWhenNoMatch(t *testing.T) {
chain := nbdns.NewHandlerChain()
high := &answeringHandler{name: "high", ip: "10.0.0.1"}
chain.AddHandler("example.com.", high, nbdns.PriorityMgmtCache)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
_, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.Error(t, err, "no handler at or below maxPriority should error")
}
// rawWriteHandler packs a response and calls ResponseWriter.Write directly
// (instead of WriteMsg), exercising the internalResponseWriter.Write path.
type rawWriteHandler struct {
ip string
}
func (h *rawWriteHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
resp := &dns.Msg{}
resp.SetReply(r)
resp.Answer = []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: r.Question[0].Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP(h.ip).To4(),
}}
packed, err := resp.Pack()
if err != nil {
return
}
_, _ = w.Write(packed)
}
func TestHandlerChain_ResolveInternal_CapturesRawWrite(t *testing.T) {
chain := nbdns.NewHandlerChain()
chain.AddHandler("example.com.", &rawWriteHandler{ip: "10.0.0.3"}, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
resp, err := chain.ResolveInternal(context.Background(), r, nbdns.PriorityUpstream)
assert.NoError(t, err)
require.NotNil(t, resp)
require.Len(t, resp.Answer, 1)
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok)
assert.Equal(t, "10.0.0.3", a.A.String(), "handlers calling Write(packed) must still surface their answer")
}
func TestHandlerChain_ResolveInternal_EmptyQuestion(t *testing.T) {
chain := nbdns.NewHandlerChain()
_, err := chain.ResolveInternal(context.Background(), new(dns.Msg), nbdns.PriorityUpstream)
assert.Error(t, err)
}
// hangingHandler blocks indefinitely until closed, simulating a wedged upstream.
type hangingHandler struct {
block chan struct{}
}
func (h *hangingHandler) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
<-h.block
resp := &dns.Msg{}
resp.SetReply(r)
_ = w.WriteMsg(resp)
}
func (h *hangingHandler) String() string { return "hangingHandler" }
func TestHandlerChain_ResolveInternal_HonorsContextTimeout(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &hangingHandler{block: make(chan struct{})}
defer close(h.block)
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
r := new(dns.Msg)
r.SetQuestion("example.com.", dns.TypeA)
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
defer cancel()
start := time.Now()
_, err := chain.ResolveInternal(ctx, r, nbdns.PriorityUpstream)
elapsed := time.Since(start)
assert.Error(t, err)
assert.ErrorIs(t, err, context.DeadlineExceeded)
assert.Less(t, elapsed, 500*time.Millisecond, "ResolveInternal must return shortly after ctx deadline")
}
func TestHandlerChain_HasRootHandlerAtOrBelow(t *testing.T) {
chain := nbdns.NewHandlerChain()
h := &answeringHandler{name: "h", ip: "10.0.0.1"}
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "empty chain")
chain.AddHandler("example.com.", h, nbdns.PriorityUpstream)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "non-root handler does not count")
chain.AddHandler(".", h, nbdns.PriorityMgmtCache)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler above threshold excluded")
chain.AddHandler(".", h, nbdns.PriorityDefault)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root handler at PriorityDefault included")
chain.RemoveHandler(".", nbdns.PriorityDefault)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
// Primary nsgroup case: root handler lands at PriorityUpstream.
chain.AddHandler(".", h, nbdns.PriorityUpstream)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityUpstream included")
chain.RemoveHandler(".", nbdns.PriorityUpstream)
// Fallback case: original /etc/resolv.conf entries land at PriorityFallback.
chain.AddHandler(".", h, nbdns.PriorityFallback)
assert.True(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream), "root at PriorityFallback included")
chain.RemoveHandler(".", nbdns.PriorityFallback)
assert.False(t, chain.HasRootHandlerAtOrBelow(nbdns.PriorityUpstream))
}

View File

@@ -46,12 +46,12 @@ type restoreHostManager interface {
} }
func newHostManager(wgInterface string) (hostManager, error) { func newHostManager(wgInterface string) (hostManager, error) {
osManager, reason, err := getOSDNSManagerType() osManager, err := getOSDNSManagerType()
if err != nil { if err != nil {
return nil, fmt.Errorf("get os dns manager type: %w", err) return nil, fmt.Errorf("get os dns manager type: %w", err)
} }
log.Infof("System DNS manager discovered: %s (%s)", osManager, reason) log.Infof("System DNS manager discovered: %s", osManager)
mgr, err := newHostManagerFromType(wgInterface, osManager) mgr, err := newHostManagerFromType(wgInterface, osManager)
// need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value // need to explicitly return nil mgr on error to avoid returning a non-nil interface containing a nil value
if err != nil { if err != nil {
@@ -74,49 +74,17 @@ func newHostManagerFromType(wgInterface string, osManager osManagerType) (restor
} }
} }
func getOSDNSManagerType() (osManagerType, string, error) { func getOSDNSManagerType() (osManagerType, error) {
resolved := isSystemdResolvedRunning()
nss := isLibnssResolveUsed()
stub := checkStub()
// Prefer systemd-resolved whenever it owns libc resolution, regardless of
// who wrote /etc/resolv.conf. File-mode rewrites do not affect lookups
// that go through nss-resolve, and in foreign mode they can loop back
// through resolved as an upstream.
if resolved && (nss || stub) {
return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub), nil
}
mgr, reason, rejected, err := scanResolvConfHeader()
if err != nil {
return 0, "", err
}
if reason != "" {
return mgr, reason, nil
}
fallback := fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
if len(rejected) > 0 {
fallback += "; rejected: " + strings.Join(rejected, ", ")
}
return fileManager, fallback, nil
}
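
Condensed into a standalone sketch with stubbed probes (the real probes talk to dbus and parse config files), the detection order above is: systemd-resolved wins outright when it owns libc resolution, then the resolv.conf header scan, then file mode as the fallback.

package main

import "fmt"

type managerType string

const (
	systemdManager managerType = "systemd"
	fileManager    managerType = "file"
)

// Stubbed probes; return values here are placeholders.
func isSystemdResolvedRunning() bool { return true }
func isLibnssResolveUsed() bool      { return false }
func checkStub() bool                { return true }

func detect() (managerType, string) {
	resolved := isSystemdResolvedRunning()
	nss := isLibnssResolveUsed()
	stub := checkStub()
	// systemd-resolved owns libc resolution: prefer it regardless of who
	// wrote /etc/resolv.conf.
	if resolved && (nss || stub) {
		return systemdManager, fmt.Sprintf("systemd-resolved active (nss-resolve=%t, stub=%t)", nss, stub)
	}
	// (header scan elided) fall back to rewriting resolv.conf directly.
	return fileManager, fmt.Sprintf("no manager matched (resolved=%t, nss-resolve=%t, stub=%t)", resolved, nss, stub)
}

func main() {
	mgr, reason := detect()
	fmt.Printf("manager=%s reason=%q\n", mgr, reason)
}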
// scanResolvConfHeader walks /etc/resolv.conf header comments and returns the
// matching manager. If reason is empty the caller should pick file mode and
// use rejected for diagnostics.
func scanResolvConfHeader() (osManagerType, string, []string, error) {
file, err := os.Open(defaultResolvConfPath) file, err := os.Open(defaultResolvConfPath)
if err != nil { if err != nil {
return 0, "", nil, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err) return 0, fmt.Errorf("unable to open %s for checking owner, got error: %w", defaultResolvConfPath, err)
} }
defer func() { defer func() {
if cerr := file.Close(); cerr != nil { if err := file.Close(); err != nil {
log.Errorf("close file %s: %s", defaultResolvConfPath, cerr) log.Errorf("close file %s: %s", defaultResolvConfPath, err)
} }
}() }()
var rejected []string
scanner := bufio.NewScanner(file) scanner := bufio.NewScanner(file)
for scanner.Scan() { for scanner.Scan() {
text := scanner.Text() text := scanner.Text()
@@ -124,48 +92,41 @@ func scanResolvConfHeader() (osManagerType, string, []string, error) {
continue continue
} }
if text[0] != '#' { if text[0] != '#' {
break return fileManager, nil
} }
if mgr, reason, rej := matchResolvConfHeader(text); reason != "" { if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return mgr, reason, nil, nil return netbirdManager, nil
} else if rej != "" { }
rejected = append(rejected, rej) if strings.Contains(text, "NetworkManager") && isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, nil
}
if strings.Contains(text, "systemd-resolved") && isSystemdResolvedRunning() {
if checkStub() {
return systemdManager, nil
} else {
return fileManager, nil
}
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, nil
}
return resolvConfManager, nil
} }
} }
if err := scanner.Err(); err != nil && err != io.EOF { if err := scanner.Err(); err != nil && err != io.EOF {
return 0, "", nil, fmt.Errorf("scan: %w", err) return 0, fmt.Errorf("scan: %w", err)
} }
return 0, "", rejected, nil
return fileManager, nil
} }
// matchResolvConfHeader inspects a single comment line. Returns either a
// definitive (manager, reason) or a non-empty rejected diagnostic.
func matchResolvConfHeader(text string) (osManagerType, string, string) {
if strings.Contains(text, fileGeneratedResolvConfContentHeader) {
return netbirdManager, "netbird-managed resolv.conf header detected", ""
}
if strings.Contains(text, "NetworkManager") {
if isDbusListenerRunning(networkManagerDest, networkManagerDbusObjectNode) && isNetworkManagerSupported() {
return networkManager, "NetworkManager header + supported version on dbus", ""
}
return 0, "", "NetworkManager header (no dbus or unsupported version)"
}
if strings.Contains(text, "resolvconf") {
if isSystemdResolveConfMode() {
return systemdManager, "resolvconf header in systemd-resolved compatibility mode", ""
}
return resolvConfManager, "resolvconf header detected", ""
}
return 0, "", ""
}
// checkStub reports whether systemd-resolved's stub (127.0.0.53) is listed // checkStub checks if the stub resolver is disabled in systemd-resolved. If it is disabled, we fall back to file manager.
// in /etc/resolv.conf. On parse failure we assume it is, to avoid dropping
// into file mode while resolved is active.
func checkStub() bool { func checkStub() bool {
rConf, err := parseDefaultResolvConf() rConf, err := parseDefaultResolvConf()
if err != nil { if err != nil {
log.Warnf("failed to parse resolv conf, assuming stub is active: %s", err) log.Warnf("failed to parse resolv conf: %s", err)
return true return true
} }
@@ -178,36 +139,3 @@ func checkStub() bool {
return false return false
} }
// isLibnssResolveUsed reports whether nss-resolve is listed before dns on
// the hosts: line of /etc/nsswitch.conf. When it is, libc lookups are
// delegated to systemd-resolved regardless of /etc/resolv.conf.
func isLibnssResolveUsed() bool {
bs, err := os.ReadFile(nsswitchConfPath)
if err != nil {
log.Debugf("read %s: %v", nsswitchConfPath, err)
return false
}
return parseNsswitchResolveAhead(bs)
}
func parseNsswitchResolveAhead(data []byte) bool {
for _, line := range strings.Split(string(data), "\n") {
if i := strings.IndexByte(line, '#'); i >= 0 {
line = line[:i]
}
fields := strings.Fields(line)
if len(fields) < 2 || fields[0] != "hosts:" {
continue
}
for _, module := range fields[1:] {
switch module {
case "dns":
return false
case "resolve":
return true
}
}
}
return false
}

View File

@@ -1,76 +0,0 @@
//go:build (linux && !android) || freebsd
package dns
import "testing"
func TestParseNsswitchResolveAhead(t *testing.T) {
tests := []struct {
name string
in string
want bool
}{
{
name: "resolve before dns with action token",
in: "hosts: mymachines resolve [!UNAVAIL=return] files myhostname dns\n",
want: true,
},
{
name: "dns before resolve",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns resolve\n",
want: false,
},
{
name: "debian default with only dns",
in: "hosts: files mdns4_minimal [NOTFOUND=return] dns mymachines\n",
want: false,
},
{
name: "neither resolve nor dns",
in: "hosts: files myhostname\n",
want: false,
},
{
name: "no hosts line",
in: "passwd: files systemd\ngroup: files systemd\n",
want: false,
},
{
name: "empty",
in: "",
want: false,
},
{
name: "comments and blank lines ignored",
in: "# comment\n\n# another\nhosts: resolve dns\n",
want: true,
},
{
name: "trailing inline comment",
in: "hosts: resolve [!UNAVAIL=return] dns # fallback\n",
want: true,
},
{
name: "hosts token must be the first field",
in: " hosts: resolve dns\n",
want: true,
},
{
name: "other db line mentioning resolve is ignored",
in: "networks: resolve\nhosts: dns\n",
want: false,
},
{
name: "only resolve, no dns",
in: "hosts: files resolve\n",
want: true,
},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
if got := parseNsswitchResolveAhead([]byte(tt.in)); got != tt.want {
t.Errorf("parseNsswitchResolveAhead() = %v, want %v", got, tt.want)
}
})
}
}

View File

@@ -2,83 +2,40 @@ package mgmt
import ( import (
"context" "context"
"errors"
"fmt" "fmt"
"net" "net"
"net/netip"
"net/url" "net/url"
"os"
"slices"
"strings" "strings"
"sync" "sync"
"sync/atomic"
"time" "time"
"github.com/miekg/dns" "github.com/miekg/dns"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/sync/singleflight"
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config" dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
"github.com/netbirdio/netbird/client/internal/dns/resutil"
"github.com/netbirdio/netbird/shared/management/domain" "github.com/netbirdio/netbird/shared/management/domain"
) )
const ( const dnsTimeout = 5 * time.Second
dnsTimeout = 5 * time.Second
defaultTTL = 300 * time.Second
refreshBackoff = 30 * time.Second
// envMgmtCacheTTL overrides defaultTTL for integration/dev testing.
envMgmtCacheTTL = "NB_MGMT_CACHE_TTL"
)
// ChainResolver lets the cache refresh stale entries through the DNS handler
// chain instead of net.DefaultResolver, avoiding loopback when NetBird is the
// system resolver.
type ChainResolver interface {
ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error)
HasRootHandlerAtOrBelow(maxPriority int) bool
}
// cachedRecord holds DNS records plus timestamps used for TTL refresh.
// records and cachedAt are set at construction and treated as immutable;
// lastFailedRefresh and consecFailures are mutable and must be accessed under
// Resolver.mutex.
type cachedRecord struct {
records []dns.RR
cachedAt time.Time
lastFailedRefresh time.Time
consecFailures int
}
// Resolver caches critical NetBird infrastructure domains.
// records, refreshing, mgmtDomain and serverDomains are all guarded by mutex.
type Resolver struct { type Resolver struct {
records map[dns.Question]*cachedRecord records map[dns.Question][]dns.RR
mgmtDomain *domain.Domain mgmtDomain *domain.Domain
serverDomains *dnsconfig.ServerDomains serverDomains *dnsconfig.ServerDomains
mutex sync.RWMutex mutex sync.RWMutex
}
chain ChainResolver type ipsResponse struct {
chainMaxPriority int ips []netip.Addr
refreshGroup singleflight.Group err error
// refreshing tracks questions whose refresh is running via the OS
// fallback path. A ServeDNS hit for a question in this map indicates
// the OS resolver routed the recursive query back to us (loop). Only
// the OS path arms this so chain-path refreshes don't produce false
// positives. The atomic bool is CAS-flipped once per refresh to
// throttle the warning log.
refreshing map[dns.Question]*atomic.Bool
cacheTTL time.Duration
} }
// NewResolver creates a new management domains cache resolver. // NewResolver creates a new management domains cache resolver.
func NewResolver() *Resolver { func NewResolver() *Resolver {
return &Resolver{ return &Resolver{
records: make(map[dns.Question]*cachedRecord), records: make(map[dns.Question][]dns.RR),
refreshing: make(map[dns.Question]*atomic.Bool),
cacheTTL: resolveCacheTTL(),
} }
} }
@@ -87,19 +44,7 @@ func (m *Resolver) String() string {
return "MgmtCacheResolver" return "MgmtCacheResolver"
} }
// SetChainResolver wires the handler chain used to refresh stale cache entries.
// maxPriority caps which handlers may answer refresh queries (typically
// PriorityUpstream, so upstream/default/fallback handlers are consulted and
// mgmt/route/local handlers are skipped).
func (m *Resolver) SetChainResolver(chain ChainResolver, maxPriority int) {
m.mutex.Lock()
m.chain = chain
m.chainMaxPriority = maxPriority
m.mutex.Unlock()
}
// ServeDNS serves cached A/AAAA records. Stale entries are returned // ServeDNS implements dns.Handler interface.
// immediately and refreshed asynchronously (stale-while-revalidate).
func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) { func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
if len(r.Question) == 0 { if len(r.Question) == 0 {
m.continueToNext(w, r) m.continueToNext(w, r)
@@ -115,14 +60,7 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
} }
m.mutex.RLock() m.mutex.RLock()
cached, found := m.records[question] records, found := m.records[question]
inflight := m.refreshing[question]
var shouldRefresh bool
if found {
stale := time.Since(cached.cachedAt) > m.cacheTTL
inBackoff := !cached.lastFailedRefresh.IsZero() && time.Since(cached.lastFailedRefresh) < refreshBackoff
shouldRefresh = stale && !inBackoff
}
m.mutex.RUnlock() m.mutex.RUnlock()
if !found { if !found {
@@ -130,23 +68,12 @@ func (m *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
return return
} }
if inflight != nil && inflight.CompareAndSwap(false, true) {
log.Warnf("mgmt cache: possible resolver loop for domain=%s: served stale while an OS-fallback refresh was inflight (if NetBird is the system resolver, the OS-path predicate is wrong)",
question.Name)
}
// Skip scheduling a refresh goroutine if one is already inflight for
// this question; singleflight would dedup anyway but skipping avoids
// a parked goroutine per stale hit under bursty load.
if shouldRefresh && inflight == nil {
m.scheduleRefresh(question, cached)
}
resp := &dns.Msg{} resp := &dns.Msg{}
resp.SetReply(r) resp.SetReply(r)
resp.Authoritative = false resp.Authoritative = false
resp.RecursionAvailable = true resp.RecursionAvailable = true
resp.Answer = cloneRecordsWithTTL(cached.records, m.responseTTL(cached.cachedAt))
resp.Answer = append(resp.Answer, records...)
log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name) log.Debugf("serving %d cached records for domain=%s", len(resp.Answer), question.Name)
@@ -171,260 +98,101 @@ func (m *Resolver) continueToNext(w dns.ResponseWriter, r *dns.Msg) {
} }
} }
// AddDomain resolves a domain and stores its A/AAAA records in the cache. // AddDomain manually adds a domain to cache by resolving it.
// A family that resolves to NODATA (nil err, zero records) evicts any stale
// entry for that qtype.
func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error { func (m *Resolver) AddDomain(ctx context.Context, d domain.Domain) error {
dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString())) dnsName := strings.ToLower(dns.Fqdn(d.PunycodeString()))
ctx, cancel := context.WithTimeout(ctx, dnsTimeout) ctx, cancel := context.WithTimeout(ctx, dnsTimeout)
defer cancel() defer cancel()
aRecords, aaaaRecords, errA, errAAAA := m.lookupBoth(ctx, d, dnsName) ips, err := lookupIPWithExtraTimeout(ctx, d)
if err != nil {
if errA != nil && errAAAA != nil { return err
return fmt.Errorf("resolve %s: %w", d.SafeString(), errors.Join(errA, errAAAA))
} }
if len(aRecords) == 0 && len(aaaaRecords) == 0 { var aRecords, aaaaRecords []dns.RR
if err := errors.Join(errA, errAAAA); err != nil { for _, ip := range ips {
return fmt.Errorf("resolve %s: no A/AAAA records: %w", d.SafeString(), err) if ip.Is4() {
rr := &dns.A{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeA,
Class: dns.ClassINET,
Ttl: 300,
},
A: ip.AsSlice(),
}
aRecords = append(aRecords, rr)
} else if ip.Is6() {
rr := &dns.AAAA{
Hdr: dns.RR_Header{
Name: dnsName,
Rrtype: dns.TypeAAAA,
Class: dns.ClassINET,
Ttl: 300,
},
AAAA: ip.AsSlice(),
}
aaaaRecords = append(aaaaRecords, rr)
} }
return fmt.Errorf("resolve %s: no A/AAAA records", d.SafeString())
} }
now := time.Now()
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock()
m.applyFamilyRecords(dnsName, dns.TypeA, aRecords, errA, now) if len(aRecords) > 0 {
m.applyFamilyRecords(dnsName, dns.TypeAAAA, aaaaRecords, errAAAA, now) aQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeA,
Qclass: dns.ClassINET,
}
m.records[aQuestion] = aRecords
}
log.Debugf("added/updated domain=%s with %d A records and %d AAAA records", if len(aaaaRecords) > 0 {
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
m.records[aaaaQuestion] = aaaaRecords
}
m.mutex.Unlock()
log.Debugf("added domain=%s with %d A records and %d AAAA records",
d.SafeString(), len(aRecords), len(aaaaRecords)) d.SafeString(), len(aRecords), len(aaaaRecords))
return nil return nil
} }
// applyFamilyRecords writes records, evicts on NODATA, leaves the cache func lookupIPWithExtraTimeout(ctx context.Context, d domain.Domain) ([]netip.Addr, error) {
// untouched on error. Caller holds m.mutex. log.Infof("looking up IP for mgmt domain=%s", d.SafeString())
func (m *Resolver) applyFamilyRecords(dnsName string, qtype uint16, records []dns.RR, err error, now time.Time) { defer log.Infof("done looking up IP for mgmt domain=%s", d.SafeString())
q := dns.Question{Name: dnsName, Qtype: qtype, Qclass: dns.ClassINET} resultChan := make(chan *ipsResponse, 1)
switch {
case len(records) > 0:
m.records[q] = &cachedRecord{records: records, cachedAt: now}
case err == nil:
delete(m.records, q)
}
}
// scheduleRefresh kicks off an async refresh. DoChan spawns one goroutine per go func() {
// unique in-flight key; bursty stale hits share its channel. expected is the ips, err := net.DefaultResolver.LookupNetIP(ctx, "ip", d.PunycodeString())
// cachedRecord pointer observed by the caller; the refresh only mutates the resultChan <- &ipsResponse{
// cache if that pointer is still the one stored, so a stale in-flight refresh err: err,
// can't clobber a newer entry written by AddDomain or a competing refresh. ips: ips,
func (m *Resolver) scheduleRefresh(question dns.Question, expected *cachedRecord) {
key := question.Name + "|" + dns.TypeToString[question.Qtype]
_ = m.refreshGroup.DoChan(key, func() (any, error) {
return nil, m.refreshQuestion(question, expected)
})
}
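
scheduleRefresh relies on singleflight.DoChan keyed by question name and type. In isolation, the collapsing behavior looks like this; the key and timings are illustrative:

package main

import (
	"fmt"
	"sync"
	"sync/atomic"
	"time"

	"golang.org/x/sync/singleflight"
)

func main() {
	var g singleflight.Group
	var executions atomic.Int32

	var wg sync.WaitGroup
	for i := 0; i < 10; i++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			ch := g.DoChan("mgmt.example.com.|A", func() (any, error) {
				executions.Add(1)
				time.Sleep(20 * time.Millisecond) // simulate the DNS lookup
				return "10.0.0.2", nil
			})
			<-ch // every caller receives the same shared Result
		}()
	}
	wg.Wait()
	fmt.Println("executions:", executions.Load()) // 1: the burst collapsed
}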
// refreshQuestion replaces the cached records on success, or marks the entry
// failed (arming the backoff) on failure. While this runs, ServeDNS can detect
// a resolver loop by spotting a query for this same question arriving on us.
// expected pins the cache entry observed at schedule time; mutations only apply
// if m.records[question] still points at it.
func (m *Resolver) refreshQuestion(question dns.Question, expected *cachedRecord) error {
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
defer cancel()
d, err := domain.FromString(strings.TrimSuffix(question.Name, "."))
if err != nil {
m.markRefreshFailed(question, expected)
return fmt.Errorf("parse domain: %w", err)
}
records, err := m.lookupRecords(ctx, d, question)
if err != nil {
fails := m.markRefreshFailed(question, expected)
logf := log.Warnf
if fails == 0 || fails > 1 {
logf = log.Debugf
} }
logf("refresh mgmt cache domain=%s type=%s: %v (consecutive failures=%d)", }()
d.SafeString(), dns.TypeToString[question.Qtype], err, fails)
return err var resp *ipsResponse
select {
case <-time.After(dnsTimeout + time.Millisecond*500):
log.Warnf("timed out waiting for IP for mgmt domain=%s", d.SafeString())
return nil, fmt.Errorf("timed out waiting for ips to be available for domain %s", d.SafeString())
case <-ctx.Done():
return nil, ctx.Err()
case resp = <-resultChan:
} }
// NOERROR/NODATA: family gone upstream, evict so we stop serving stale. if resp.err != nil {
if len(records) == 0 { return nil, fmt.Errorf("resolve domain %s: %w", d.SafeString(), resp.err)
m.mutex.Lock()
if m.records[question] == expected {
delete(m.records, question)
m.mutex.Unlock()
log.Infof("removed mgmt cache domain=%s type=%s: no records returned",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.mutex.Unlock()
log.Debugf("skipping refresh evict for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
} }
return resp.ips, nil
now := time.Now()
m.mutex.Lock()
if m.records[question] != expected {
m.mutex.Unlock()
log.Debugf("skipping refresh write for domain=%s type=%s: entry changed during refresh",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
m.records[question] = &cachedRecord{records: records, cachedAt: now}
m.mutex.Unlock()
log.Infof("refreshed mgmt cache domain=%s type=%s",
d.SafeString(), dns.TypeToString[question.Qtype])
return nil
}
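
The expected-pointer check in refreshQuestion is optimistic concurrency: a refresh result is only installed if the map still holds the exact entry observed at schedule time. The pattern in miniature (types hypothetical):

package main

import (
	"fmt"
	"sync"
)

type entry struct{ value string }

type cache struct {
	mu sync.Mutex
	m  map[string]*entry
}

// replaceIf installs next only if the slot still holds expected.
func (c *cache) replaceIf(key string, expected, next *entry) bool {
	c.mu.Lock()
	defer c.mu.Unlock()
	if c.m[key] != expected {
		return false // entry changed while we were resolving; drop our result
	}
	c.m[key] = next
	return true
}

func main() {
	c := &cache{m: map[string]*entry{}}
	old := &entry{value: "10.0.0.1"}
	c.m["mgmt"] = old

	// A competing writer (e.g. AddDomain) replaces the entry first.
	c.m["mgmt"] = &entry{value: "10.0.0.9"}

	ok := c.replaceIf("mgmt", old, &entry{value: "10.0.0.2"})
	fmt.Println(ok, c.m["mgmt"].value) // false 10.0.0.9: stale refresh discarded
}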
func (m *Resolver) markRefreshing(question dns.Question) {
m.mutex.Lock()
m.refreshing[question] = &atomic.Bool{}
m.mutex.Unlock()
}
func (m *Resolver) clearRefreshing(question dns.Question) {
m.mutex.Lock()
delete(m.refreshing, question)
m.mutex.Unlock()
}
// markRefreshFailed arms the backoff and returns the new consecutive-failure
// count so callers can downgrade subsequent failure logs to debug.
func (m *Resolver) markRefreshFailed(question dns.Question, expected *cachedRecord) int {
m.mutex.Lock()
defer m.mutex.Unlock()
c, ok := m.records[question]
if !ok || c != expected {
return 0
}
c.lastFailedRefresh = time.Now()
c.consecFailures++
return c.consecFailures
}
// lookupBoth resolves A and AAAA via chain or OS. Per-family errors let
// callers distinguish records, NODATA (nil err, no records), and failure.
func (m *Resolver) lookupBoth(ctx context.Context, d domain.Domain, dnsName string) (aRecords, aaaaRecords []dns.RR, errA, errAAAA error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
aRecords, errA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.lookupViaChain(ctx, chain, maxPriority, dnsName, dns.TypeAAAA)
return
}
// TODO: drop once every supported OS registers a fallback resolver. Safe
// today: no root handler at priority ≤ PriorityUpstream means NetBird is
// not the system resolver, so net.DefaultResolver will not loop back.
aRecords, errA = m.osLookup(ctx, d, dnsName, dns.TypeA)
aaaaRecords, errAAAA = m.osLookup(ctx, d, dnsName, dns.TypeAAAA)
return
}
// lookupRecords resolves a single record type via chain or OS. The OS branch
// arms the loop detector for the duration of its call so that ServeDNS can
// spot the OS resolver routing the recursive query back to us.
func (m *Resolver) lookupRecords(ctx context.Context, d domain.Domain, q dns.Question) ([]dns.RR, error) {
m.mutex.RLock()
chain := m.chain
maxPriority := m.chainMaxPriority
m.mutex.RUnlock()
if chain != nil && chain.HasRootHandlerAtOrBelow(maxPriority) {
return m.lookupViaChain(ctx, chain, maxPriority, q.Name, q.Qtype)
}
// TODO: drop once every supported OS registers a fallback resolver.
m.markRefreshing(q)
defer m.clearRefreshing(q)
return m.osLookup(ctx, d, q.Name, q.Qtype)
}
// lookupViaChain resolves via the handler chain and rewrites each RR to use
// dnsName as owner and m.cacheTTL as TTL, so CNAME-backed domains don't cache
// target-owned records or upstream TTLs. NODATA returns (nil, nil).
func (m *Resolver) lookupViaChain(ctx context.Context, chain ChainResolver, maxPriority int, dnsName string, qtype uint16) ([]dns.RR, error) {
msg := &dns.Msg{}
msg.SetQuestion(dnsName, qtype)
msg.RecursionDesired = true
resp, err := chain.ResolveInternal(ctx, msg, maxPriority)
if err != nil {
return nil, fmt.Errorf("chain resolve: %w", err)
}
if resp == nil {
return nil, fmt.Errorf("chain resolve returned nil response")
}
if resp.Rcode != dns.RcodeSuccess {
return nil, fmt.Errorf("chain resolve rcode=%s", dns.RcodeToString[resp.Rcode])
}
ttl := uint32(m.cacheTTL.Seconds())
owners := cnameOwners(dnsName, resp.Answer)
var filtered []dns.RR
for _, rr := range resp.Answer {
h := rr.Header()
if h.Class != dns.ClassINET || h.Rrtype != qtype {
continue
}
if !owners[strings.ToLower(dns.Fqdn(h.Name))] {
continue
}
if cp := cloneIPRecord(rr, dnsName, ttl); cp != nil {
filtered = append(filtered, cp)
}
}
return filtered, nil
}
// osLookup resolves a single family via net.DefaultResolver using resutil,
// which disambiguates NODATA from NXDOMAIN and Unmaps v4-mapped-v6. NODATA
// returns (nil, nil).
func (m *Resolver) osLookup(ctx context.Context, d domain.Domain, dnsName string, qtype uint16) ([]dns.RR, error) {
network := resutil.NetworkForQtype(qtype)
if network == "" {
return nil, fmt.Errorf("unsupported qtype %s", dns.TypeToString[qtype])
}
log.Infof("looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
defer log.Infof("done looking up IP for mgmt domain=%s type=%s", d.SafeString(), dns.TypeToString[qtype])
result := resutil.LookupIP(ctx, net.DefaultResolver, network, d.PunycodeString(), qtype)
if result.Rcode == dns.RcodeSuccess {
return resutil.IPsToRRs(dnsName, result.IPs, uint32(m.cacheTTL.Seconds())), nil
}
if result.Err != nil {
return nil, fmt.Errorf("resolve %s type=%s: %w", d.SafeString(), dns.TypeToString[qtype], result.Err)
}
return nil, fmt.Errorf("resolve %s type=%s: rcode=%s", d.SafeString(), dns.TypeToString[qtype], dns.RcodeToString[result.Rcode])
}
// responseTTL returns the remaining cache lifetime in seconds (rounded up),
// so downstream resolvers don't cache an answer for longer than we will.
func (m *Resolver) responseTTL(cachedAt time.Time) uint32 {
remaining := m.cacheTTL - time.Since(cachedAt)
if remaining <= 0 {
return 0
}
return uint32((remaining + time.Second - 1) / time.Second)
} }
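
The rounding in responseTTL is a ceiling division on the remaining lifetime, so a sub-second remainder still yields TTL 1 instead of 0. Isolated and runnable:

package main

import (
	"fmt"
	"time"
)

func responseTTL(cacheTTL time.Duration, cachedAt time.Time) uint32 {
	remaining := cacheTTL - time.Since(cachedAt)
	if remaining <= 0 {
		return 0
	}
	// integer ceiling: (remaining + 1s - 1ns) / 1s
	return uint32((remaining + time.Second - 1) / time.Second)
}

func main() {
	ttl := 300 * time.Second
	fmt.Println(responseTTL(ttl, time.Now()))                                // ~300
	fmt.Println(responseTTL(ttl, time.Now().Add(-299500*time.Millisecond)))  // 1 (rounded up)
	fmt.Println(responseTTL(ttl, time.Now().Add(-10*time.Minute)))           // 0 (already stale)
}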
// PopulateFromConfig extracts and caches domains from the client configuration. // PopulateFromConfig extracts and caches domains from the client configuration.
@@ -456,12 +224,19 @@ func (m *Resolver) RemoveDomain(d domain.Domain) error {
m.mutex.Lock() m.mutex.Lock()
defer m.mutex.Unlock() defer m.mutex.Unlock()
qA := dns.Question{Name: dnsName, Qtype: dns.TypeA, Qclass: dns.ClassINET} aQuestion := dns.Question{
qAAAA := dns.Question{Name: dnsName, Qtype: dns.TypeAAAA, Qclass: dns.ClassINET} Name: dnsName,
delete(m.records, qA) Qtype: dns.TypeA,
delete(m.records, qAAAA) Qclass: dns.ClassINET,
delete(m.refreshing, qA) }
delete(m.refreshing, qAAAA) delete(m.records, aQuestion)
aaaaQuestion := dns.Question{
Name: dnsName,
Qtype: dns.TypeAAAA,
Qclass: dns.ClassINET,
}
delete(m.records, aaaaQuestion)
log.Debugf("removed domain=%s from cache", d.SafeString()) log.Debugf("removed domain=%s from cache", d.SafeString())
return nil return nil
@@ -619,73 +394,3 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
return domains return domains
} }
// cloneIPRecord returns a deep copy of rr retargeted to owner with ttl.
// Non-A/AAAA records return nil.
func cloneIPRecord(rr dns.RR, owner string, ttl uint32) dns.RR {
switch r := rr.(type) {
case *dns.A:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.A = slices.Clone(r.A)
return &cp
case *dns.AAAA:
cp := *r
cp.Hdr.Name = owner
cp.Hdr.Ttl = ttl
cp.AAAA = slices.Clone(r.AAAA)
return &cp
}
return nil
}
// cloneRecordsWithTTL clones A/AAAA records preserving their owner and
// stamping ttl so the response shares no memory with the cached slice.
func cloneRecordsWithTTL(records []dns.RR, ttl uint32) []dns.RR {
out := make([]dns.RR, 0, len(records))
for _, rr := range records {
if cp := cloneIPRecord(rr, rr.Header().Name, ttl); cp != nil {
out = append(out, cp)
}
}
return out
}
// cnameOwners returns dnsName plus every target reachable by following CNAMEs
// in answer, iterating until fixed point so out-of-order chains resolve.
func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
owners := map[string]bool{dnsName: true}
for {
added := false
for _, rr := range answer {
cname, ok := rr.(*dns.CNAME)
if !ok {
continue
}
name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
if !owners[name] {
continue
}
target := strings.ToLower(dns.Fqdn(cname.Target))
if !owners[target] {
owners[target] = true
added = true
}
}
if !added {
return owners
}
}
}
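
cnameOwners iterates to a fixed point, so a chain listed out of order still resolves. A runnable demo with miekg/dns records (names illustrative; the function body is copied from above):

package main

import (
	"fmt"
	"strings"

	"github.com/miekg/dns"
)

func cnameOwners(dnsName string, answer []dns.RR) map[string]bool {
	owners := map[string]bool{dnsName: true}
	for {
		added := false
		for _, rr := range answer {
			cname, ok := rr.(*dns.CNAME)
			if !ok {
				continue
			}
			name := strings.ToLower(dns.Fqdn(cname.Hdr.Name))
			if !owners[name] {
				continue
			}
			target := strings.ToLower(dns.Fqdn(cname.Target))
			if !owners[target] {
				owners[target] = true
				added = true
			}
		}
		if !added {
			return owners
		}
	}
}

func main() {
	hdr := func(name string) dns.RR_Header {
		return dns.RR_Header{Name: name, Rrtype: dns.TypeCNAME, Class: dns.ClassINET, Ttl: 60}
	}
	// Chain listed out of order: the b->c link precedes the a->b link.
	answer := []dns.RR{
		&dns.CNAME{Hdr: hdr("b.example.com."), Target: "c.example.com."},
		&dns.CNAME{Hdr: hdr("a.example.com."), Target: "b.example.com."},
	}
	fmt.Println(cnameOwners("a.example.com.", answer))
	// map[a.example.com.:true b.example.com.:true c.example.com.:true]
}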
// resolveCacheTTL reads the cache TTL override env var; invalid or empty
// values fall back to defaultTTL. Called once per Resolver from NewResolver.
func resolveCacheTTL() time.Duration {
if v := os.Getenv(envMgmtCacheTTL); v != "" {
if d, err := time.ParseDuration(v); err == nil && d > 0 {
return d
}
}
return defaultTTL
}

View File

@@ -1,408 +0,0 @@
package mgmt
import (
"context"
"errors"
"net"
"sync"
"sync/atomic"
"testing"
"time"
"github.com/miekg/dns"
"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"github.com/netbirdio/netbird/client/internal/dns/test"
"github.com/netbirdio/netbird/shared/management/domain"
)
type fakeChain struct {
mu sync.Mutex
calls map[string]int
answers map[string][]dns.RR
err error
hasRoot bool
onLookup func()
}
func newFakeChain() *fakeChain {
return &fakeChain{
calls: map[string]int{},
answers: map[string][]dns.RR{},
hasRoot: true,
}
}
func (f *fakeChain) HasRootHandlerAtOrBelow(maxPriority int) bool {
f.mu.Lock()
defer f.mu.Unlock()
return f.hasRoot
}
func (f *fakeChain) ResolveInternal(ctx context.Context, msg *dns.Msg, maxPriority int) (*dns.Msg, error) {
f.mu.Lock()
q := msg.Question[0]
key := q.Name + "|" + dns.TypeToString[q.Qtype]
f.calls[key]++
answers := f.answers[key]
err := f.err
onLookup := f.onLookup
f.mu.Unlock()
if onLookup != nil {
onLookup()
}
if err != nil {
return nil, err
}
resp := &dns.Msg{}
resp.SetReply(msg)
resp.Answer = answers
return resp, nil
}
func (f *fakeChain) setAnswer(name string, qtype uint16, ip string) {
f.mu.Lock()
defer f.mu.Unlock()
key := name + "|" + dns.TypeToString[qtype]
hdr := dns.RR_Header{Name: name, Rrtype: qtype, Class: dns.ClassINET, Ttl: 60}
switch qtype {
case dns.TypeA:
f.answers[key] = []dns.RR{&dns.A{Hdr: hdr, A: net.ParseIP(ip).To4()}}
case dns.TypeAAAA:
f.answers[key] = []dns.RR{&dns.AAAA{Hdr: hdr, AAAA: net.ParseIP(ip).To16()}}
}
}
func (f *fakeChain) callCount(name string, qtype uint16) int {
f.mu.Lock()
defer f.mu.Unlock()
return f.calls[name+"|"+dns.TypeToString[qtype]]
}
// waitFor polls the predicate until it returns true or the deadline passes.
func waitFor(t *testing.T, d time.Duration, fn func() bool) {
t.Helper()
deadline := time.Now().Add(d)
for time.Now().Before(deadline) {
if fn() {
return
}
time.Sleep(5 * time.Millisecond)
}
t.Fatalf("condition not met within %s", d)
}
func queryA(t *testing.T, r *Resolver, name string) *dns.Msg {
t.Helper()
msg := new(dns.Msg)
msg.SetQuestion(name, dns.TypeA)
w := &test.MockResponseWriter{}
r.ServeDNS(w, msg)
return w.GetLastResponse()
}
func firstA(t *testing.T, resp *dns.Msg) string {
t.Helper()
require.NotNil(t, resp)
require.Greater(t, len(resp.Answer), 0, "expected at least one answer")
a, ok := resp.Answer[0].(*dns.A)
require.True(t, ok, "expected A record")
return a.A.String()
}
func TestResolver_CacheTTLGatesRefresh(t *testing.T) {
// Same cached entry age, different cacheTTL values: the shorter TTL must
// trigger a background refresh, the longer one must not. Proves that the
// per-Resolver cacheTTL field actually drives the stale decision.
cachedAt := time.Now().Add(-100 * time.Millisecond)
newRec := func() *cachedRecord {
return &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: cachedAt,
}
}
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
t.Run("short TTL treats entry as stale and refreshes", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = 10 * time.Millisecond
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount(q.Name, dns.TypeA) >= 1
})
})
t.Run("long TTL keeps entry fresh and skips refresh", func(t *testing.T) {
r := NewResolver()
r.cacheTTL = time.Hour
chain := newFakeChain()
chain.setAnswer(q.Name, dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[q] = newRec()
resp := queryA(t, r, q.Name)
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 0, chain.callCount(q.Name, dns.TypeA), "fresh entry must not trigger refresh")
})
}
func TestResolver_ServeFresh_NoRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
r.records[dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: "mgmt.example.com.", Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(), // fresh
}
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp))
time.Sleep(20 * time.Millisecond)
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA), "fresh entry must not trigger refresh")
}
func TestResolver_StaleTriggersAsyncRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL), // stale
}
// First query: serves stale immediately.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must be served while refresh runs")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) >= 1
})
// Next query should now return the refreshed IP.
waitFor(t, time.Second, func() bool {
resp := queryA(t, r, "mgmt.example.com.")
return resp != nil && len(resp.Answer) > 0 && firstA(t, resp) == "10.0.0.2"
})
}
func TestResolver_ConcurrentStaleHitsCollapseRefresh(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
var inflight atomic.Int32
var maxInflight atomic.Int32
chain.onLookup = func() {
cur := inflight.Add(1)
defer inflight.Add(-1)
for {
prev := maxInflight.Load()
if cur <= prev || maxInflight.CompareAndSwap(prev, cur) {
break
}
}
time.Sleep(50 * time.Millisecond) // hold inflight long enough to collide
}
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
var wg sync.WaitGroup
for i := 0; i < 50; i++ {
wg.Add(1)
go func() {
defer wg.Done()
queryA(t, r, "mgmt.example.com.")
}()
}
wg.Wait()
waitFor(t, 2*time.Second, func() bool {
return inflight.Load() == 0
})
calls := chain.callCount("mgmt.example.com.", dns.TypeA)
assert.LessOrEqual(t, calls, 2, "singleflight must collapse concurrent refreshes (got %d)", calls)
assert.Equal(t, int32(1), maxInflight.Load(), "only one refresh should run concurrently")
}
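The collapse the assertions demand follows from a per-question claim-once guard; atomic.Bool.CompareAndSwap is enough to show the idea. A sketch under that assumption (the resolver's real guard also carries the loop flag seen in later tests):

package main

import (
    "fmt"
    "sync"
    "sync/atomic"
    "time"
)

func main() {
    var (
        refreshing atomic.Bool  // per-question inflight guard
        refreshes  atomic.Int32 // upstream lookups actually performed
        wg         sync.WaitGroup
    )

    tryRefresh := func() {
        // Only the goroutine that flips false->true owns the refresh;
        // everyone else returns immediately and keeps serving stale data.
        if !refreshing.CompareAndSwap(false, true) {
            return
        }
        defer refreshing.Store(false)
        time.Sleep(50 * time.Millisecond) // stand-in for the chain lookup
        refreshes.Add(1)
    }

    for i := 0; i < 50; i++ {
        wg.Add(1)
        go func() { defer wg.Done(); tryRefresh() }()
    }
    wg.Wait()
    fmt.Println("refreshes:", refreshes.Load()) // 1: fifty stale hits, one lookup
}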
func TestResolver_RefreshFailureArmsBackoff(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.err = errors.New("boom")
r.SetChainResolver(chain, 50)
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now().Add(-2 * defaultTTL),
}
// First stale hit triggers a refresh attempt that fails.
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry served while refresh fails")
waitFor(t, time.Second, func() bool {
return chain.callCount("mgmt.example.com.", dns.TypeA) == 1
})
waitFor(t, time.Second, func() bool {
r.mutex.RLock()
defer r.mutex.RUnlock()
c, ok := r.records[q]
return ok && !c.lastFailedRefresh.IsZero()
})
// Subsequent stale hits within backoff window should not schedule more refreshes.
for i := 0; i < 10; i++ {
queryA(t, r, "mgmt.example.com.")
}
time.Sleep(50 * time.Millisecond)
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA), "backoff must suppress further refreshes")
}
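The backoff being armed is, at its core, a timestamp gate: after a failed refresh, further stale hits are ignored until a cool-down passes. A sketch assuming a lastFailedRefresh field like the one the test polls; the 30-second cool-down below is illustrative, not the resolver's actual constant:

package main

import (
    "fmt"
    "time"
)

const refreshBackoff = 30 * time.Second // illustrative cool-down

// shouldRefresh gates refresh scheduling on the last failure time.
func shouldRefresh(lastFailedRefresh, now time.Time) bool {
    return lastFailedRefresh.IsZero() || now.Sub(lastFailedRefresh) >= refreshBackoff
}

func main() {
    var lastFail time.Time
    fmt.Println(shouldRefresh(lastFail, time.Now())) // true: no failure recorded

    lastFail = time.Now() // a refresh just failed
    fmt.Println(shouldRefresh(lastFail, time.Now()))                  // false: inside backoff window
    fmt.Println(shouldRefresh(lastFail, time.Now().Add(time.Minute))) // true: window elapsed
}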
func TestResolver_NoRootHandler_SkipsChain(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.hasRoot = false
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
r.SetChainResolver(chain, 50)
// With hasRoot=false the chain must not be consulted. Use a short
// deadline so the OS fallback returns quickly without waiting on a
// real network call in CI.
ctx, cancel := context.WithTimeout(context.Background(), 50*time.Millisecond)
defer cancel()
_, _, _, _ = r.lookupBoth(ctx, domain.Domain("mgmt.example.com"), "mgmt.example.com.")
assert.Equal(t, 0, chain.callCount("mgmt.example.com.", dns.TypeA),
"chain must not be used when no root handler is registered at the bound priority")
}
func TestResolver_ServeDuringRefreshSetsLoopFlag(t *testing.T) {
// ServeDNS being invoked for a question while a refresh for that question
// is inflight indicates a resolver loop (OS resolver sent the recursive
// query back to us). The inflightRefresh.loopLoggedOnce flag must be set.
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
// Simulate an inflight refresh.
r.markRefreshing(q)
defer r.clearRefreshing(q)
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.1", firstA(t, resp), "stale entry must still be served to avoid breaking external queries")
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
require.NotNil(t, inflight)
assert.True(t, inflight.Load(), "loop flag must be set once a ServeDNS during refresh was observed")
}
func TestResolver_LoopFlagOnlyTrippedOncePerRefresh(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
r.markRefreshing(q)
defer r.clearRefreshing(q)
// Multiple ServeDNS calls during the same refresh must not re-set the flag
// (CompareAndSwap from false -> true returns true only on the first call).
for range 5 {
queryA(t, r, "mgmt.example.com.")
}
r.mutex.RLock()
inflight := r.refreshing[q]
r.mutex.RUnlock()
assert.True(t, inflight.Load())
}
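The trip-once behavior follows directly from CompareAndSwap semantics: only the first caller observes the false-to-true transition, so the loop warning fires once per refresh. Illustratively:

package main

import (
    "fmt"
    "sync/atomic"
)

func main() {
    var loopLogged atomic.Bool // reset whenever a new refresh starts

    for i := 1; i <= 5; i++ {
        if loopLogged.CompareAndSwap(false, true) {
            fmt.Printf("query %d: resolver loop detected, logging once\n", i)
        } else {
            fmt.Printf("query %d: already logged, stay quiet\n", i)
        }
    }
}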
func TestResolver_NoLoopFlagWhenNotRefreshing(t *testing.T) {
r := NewResolver()
q := dns.Question{Name: "mgmt.example.com.", Qtype: dns.TypeA, Qclass: dns.ClassINET}
r.records[q] = &cachedRecord{
records: []dns.RR{&dns.A{
Hdr: dns.RR_Header{Name: q.Name, Rrtype: dns.TypeA, Class: dns.ClassINET, Ttl: 60},
A: net.ParseIP("10.0.0.1").To4(),
}},
cachedAt: time.Now(),
}
queryA(t, r, "mgmt.example.com.")
r.mutex.RLock()
_, ok := r.refreshing[q]
r.mutex.RUnlock()
assert.False(t, ok, "no refresh inflight means no loop tracking")
}
func TestResolver_AddDomain_UsesChainWhenRootRegistered(t *testing.T) {
r := NewResolver()
chain := newFakeChain()
chain.setAnswer("mgmt.example.com.", dns.TypeA, "10.0.0.2")
chain.setAnswer("mgmt.example.com.", dns.TypeAAAA, "fd00::2")
r.SetChainResolver(chain, 50)
require.NoError(t, r.AddDomain(context.Background(), domain.Domain("mgmt.example.com")))
resp := queryA(t, r, "mgmt.example.com.")
assert.Equal(t, "10.0.0.2", firstA(t, resp))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeA))
assert.Equal(t, 1, chain.callCount("mgmt.example.com.", dns.TypeAAAA))
}
View File
@@ -6,7 +6,6 @@ import (
"net/url" "net/url"
"strings" "strings"
"testing" "testing"
"time"
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
@@ -24,60 +23,6 @@ func TestResolver_NewResolver(t *testing.T) {
assert.False(t, resolver.MatchSubdomains()) assert.False(t, resolver.MatchSubdomains())
} }
func TestResolveCacheTTL(t *testing.T) {
tests := []struct {
name string
value string
want time.Duration
}{
{"unset falls back to default", "", defaultTTL},
{"valid duration", "45s", 45 * time.Second},
{"valid minutes", "2m", 2 * time.Minute},
{"malformed falls back to default", "not-a-duration", defaultTTL},
{"zero falls back to default", "0s", defaultTTL},
{"negative falls back to default", "-5s", defaultTTL},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
t.Setenv(envMgmtCacheTTL, tc.value)
got := resolveCacheTTL()
assert.Equal(t, tc.want, got, "parsed TTL should match")
})
}
}
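For reference, the fallback rules this removed table test encoded fit in a few lines: unset, malformed, zero, and negative values all yield the default. A sketch with an illustrative env key and default (the real defaultTTL constant is not shown in this diff):

package main

import (
    "fmt"
    "os"
    "time"
)

const fallbackTTL = 3 * time.Minute // illustrative default

// parseTTLFromEnv mirrors the table rows above: anything that is not a
// positive, well-formed duration falls back to the default.
func parseTTLFromEnv(key string) time.Duration {
    raw := os.Getenv(key)
    if raw == "" {
        return fallbackTTL
    }
    d, err := time.ParseDuration(raw)
    if err != nil || d <= 0 {
        return fallbackTTL
    }
    return d
}

func main() {
    os.Setenv("EXAMPLE_TTL", "45s")
    fmt.Println(parseTTLFromEnv("EXAMPLE_TTL")) // 45s
    os.Setenv("EXAMPLE_TTL", "-5s")
    fmt.Println(parseTTLFromEnv("EXAMPLE_TTL")) // 3m0s: fallback
}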
func TestNewResolver_CacheTTLFromEnv(t *testing.T) {
t.Setenv(envMgmtCacheTTL, "7s")
r := NewResolver()
assert.Equal(t, 7*time.Second, r.cacheTTL, "NewResolver should evaluate cacheTTL once from env")
}
func TestResolver_ResponseTTL(t *testing.T) {
now := time.Now()
tests := []struct {
name string
cacheTTL time.Duration
cachedAt time.Time
wantMin uint32
wantMax uint32
}{
{"fresh entry returns full TTL", 60 * time.Second, now, 59, 60},
{"half-aged entry returns half TTL", 60 * time.Second, now.Add(-30 * time.Second), 29, 31},
{"expired entry returns zero", 60 * time.Second, now.Add(-61 * time.Second), 0, 0},
{"exactly expired returns zero", 10 * time.Second, now.Add(-10 * time.Second), 0, 0},
}
for _, tc := range tests {
t.Run(tc.name, func(t *testing.T) {
r := &Resolver{cacheTTL: tc.cacheTTL}
got := r.responseTTL(tc.cachedAt)
assert.GreaterOrEqual(t, got, tc.wantMin, "remaining TTL should be >= wantMin")
assert.LessOrEqual(t, got, tc.wantMax, "remaining TTL should be <= wantMax")
})
}
}
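The removed responseTTL rows pin down a simple clamp: the remaining TTL is cacheTTL minus the entry's age, floored at zero and reported in whole seconds. A sketch matching those rows:

package main

import (
    "fmt"
    "time"
)

// responseTTLSketch mirrors the behavior the removed table test pins down:
// remaining = cacheTTL - age, clamped at zero, truncated to whole seconds.
func responseTTLSketch(cacheTTL time.Duration, cachedAt, now time.Time) uint32 {
    remaining := cacheTTL - now.Sub(cachedAt)
    if remaining <= 0 {
        return 0
    }
    return uint32(remaining.Seconds())
}

func main() {
    now := time.Now()
    fmt.Println(responseTTLSketch(60*time.Second, now, now))                      // 60: fresh
    fmt.Println(responseTTLSketch(60*time.Second, now.Add(-30*time.Second), now)) // 30: half-aged
    fmt.Println(responseTTLSketch(60*time.Second, now.Add(-61*time.Second), now)) // 0: expired
}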
func TestResolver_ExtractDomainFromURL(t *testing.T) { func TestResolver_ExtractDomainFromURL(t *testing.T) {
tests := []struct { tests := []struct {
name string name string
View File
@@ -212,7 +212,6 @@ func newDefaultServer(
ctx, stop := context.WithCancel(ctx) ctx, stop := context.WithCancel(ctx)
mgmtCacheResolver := mgmt.NewResolver() mgmtCacheResolver := mgmt.NewResolver()
mgmtCacheResolver.SetChainResolver(handlerChain, PriorityUpstream)
defaultServer := &DefaultServer{ defaultServer := &DefaultServer{
ctx: ctx, ctx: ctx,
View File
@@ -28,7 +28,6 @@ import (
"github.com/netbirdio/netbird/client/firewall" "github.com/netbirdio/netbird/client/firewall"
"github.com/netbirdio/netbird/client/firewall/firewalld" "github.com/netbirdio/netbird/client/firewall/firewalld"
firewallManager "github.com/netbirdio/netbird/client/firewall/manager" firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
"github.com/netbirdio/netbird/client/iface" "github.com/netbirdio/netbird/client/iface"
"github.com/netbirdio/netbird/client/iface/device" "github.com/netbirdio/netbird/client/iface/device"
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack" nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
@@ -69,7 +68,6 @@ import (
signal "github.com/netbirdio/netbird/shared/signal/client" signal "github.com/netbirdio/netbird/shared/signal/client"
sProto "github.com/netbirdio/netbird/shared/signal/proto" sProto "github.com/netbirdio/netbird/shared/signal/proto"
"github.com/netbirdio/netbird/util" "github.com/netbirdio/netbird/util"
"github.com/netbirdio/netbird/util/capture"
) )
// PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer. // PeerConnectionTimeoutMax is a timeout of an initial connection attempt to a remote peer.
@@ -220,8 +218,6 @@ type Engine struct {
portForwardManager *portforward.Manager portForwardManager *portforward.Manager
srWatcher *guard.SRWatcher srWatcher *guard.SRWatcher
afpacketCapture *capture.AFPacketCapture
// Sync response persistence (protected by syncRespMux) // Sync response persistence (protected by syncRespMux)
syncRespMux sync.RWMutex syncRespMux sync.RWMutex
persistSyncResponse bool persistSyncResponse bool
@@ -948,12 +944,7 @@ func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
return fmt.Errorf("update relay token: %w", err) return fmt.Errorf("update relay token: %w", err)
} }
urls := update.Urls e.relayManager.UpdateServerURLs(update.Urls)
if override, ok := peer.OverrideRelayURLs(); ok {
log.Infof("overriding relay URLs from %s: %v", peer.EnvKeyNBHomeRelayServers, override)
urls = override
}
e.relayManager.UpdateServerURLs(urls)
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled. // Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
// We can ignore all errors because the guard will manage the reconnection retries. // We can ignore all errors because the guard will manage the reconnection retries.
@@ -1707,11 +1698,6 @@ func (e *Engine) parseNATExternalIPMappings() []string {
} }
func (e *Engine) close() { func (e *Engine) close() {
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
log.Debugf("removing Netbird interface %s", e.config.WgIfaceName) log.Debugf("removing Netbird interface %s", e.config.WgIfaceName)
if e.wgInterface != nil { if e.wgInterface != nil {
@@ -2177,62 +2163,6 @@ func (e *Engine) Address() (netip.Addr, error) {
return e.wgInterface.Address().IP, nil return e.wgInterface.Address().IP, nil
} }
// SetCapture sets or clears packet capture on the WireGuard device.
// On userspace WireGuard, it taps the FilteredDevice directly.
// On kernel WireGuard (Linux), it falls back to AF_PACKET raw socket capture.
// Pass nil to disable capture.
func (e *Engine) SetCapture(pc device.PacketCapture) error {
e.syncMsgMux.Lock()
defer e.syncMsgMux.Unlock()
intf := e.wgInterface
if intf == nil {
return errors.New("wireguard interface not initialized")
}
if e.afpacketCapture != nil {
e.afpacketCapture.Stop()
e.afpacketCapture = nil
}
dev := intf.GetDevice()
if dev != nil {
dev.SetCapture(pc)
e.setForwarderCapture(pc)
return nil
}
// Kernel mode: no FilteredDevice. Use AF_PACKET on Linux.
if pc == nil {
return nil
}
sess, ok := pc.(*capture.Session)
if !ok {
return errors.New("filtered device not available and AF_PACKET requires *capture.Session")
}
afc := capture.NewAFPacketCapture(intf.Name(), sess)
if err := afc.Start(); err != nil {
return fmt.Errorf("start AF_PACKET capture on %s: %w", intf.Name(), err)
}
e.afpacketCapture = afc
return nil
}
// setForwarderCapture propagates capture to the USP filter's forwarder endpoint.
// This captures outbound response packets that bypass the FilteredDevice in netstack mode.
func (e *Engine) setForwarderCapture(pc device.PacketCapture) {
if e.firewall == nil {
return
}
type forwarderCapturer interface {
SetPacketCapture(pc forwarder.PacketCapture)
}
if fc, ok := e.firewall.(forwarderCapturer); ok {
fc.SetPacketCapture(pc)
}
}
func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) { func (e *Engine) updateForwardRules(rules []*mgmProto.ForwardingRule) ([]firewallManager.ForwardRule, error) {
if e.firewall == nil { if e.firewall == nil {
log.Warn("firewall is disabled, not updating forwarding rules") log.Warn("firewall is disabled, not updating forwarding rules")
@@ -2454,8 +2384,6 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
} }
} }
relayIP := decodeRelayIP(msg.GetBody().GetRelayServerIP())
offerAnswer := peer.OfferAnswer{ offerAnswer := peer.OfferAnswer{
IceCredentials: peer.IceCredentials{ IceCredentials: peer.IceCredentials{
UFrag: remoteCred.UFrag, UFrag: remoteCred.UFrag,
@@ -2466,23 +2394,7 @@ func convertToOfferAnswer(msg *sProto.Message) (*peer.OfferAnswer, error) {
RosenpassPubKey: rosenpassPubKey, RosenpassPubKey: rosenpassPubKey,
RosenpassAddr: rosenpassAddr, RosenpassAddr: rosenpassAddr,
RelaySrvAddress: msg.GetBody().GetRelayServerAddress(), RelaySrvAddress: msg.GetBody().GetRelayServerAddress(),
RelaySrvIP: relayIP,
SessionID: sessionID, SessionID: sessionID,
} }
return &offerAnswer, nil return &offerAnswer, nil
} }
// decodeRelayIP decodes the proto relayServerIP bytes (4 or 16) into a
// netip.Addr. Returns the zero value for empty input and logs a warning
// for malformed payloads.
func decodeRelayIP(b []byte) netip.Addr {
if len(b) == 0 {
return netip.Addr{}
}
ip, ok := netip.AddrFromSlice(b)
if !ok {
log.Warnf("invalid relayServerIP in signal message (%d bytes), ignoring", len(b))
return netip.Addr{}
}
return ip.Unmap()
}
View File
@@ -1671,7 +1671,7 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil) mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
if err != nil { if err != nil {
return nil, "", err return nil, "", err
} }
View File
@@ -3,6 +3,7 @@ package activity
import ( import (
"net" "net"
"net/netip" "net/netip"
"runtime"
"testing" "testing"
"time" "time"
@@ -17,6 +18,10 @@ import (
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
) )
func isBindListenerPlatform() bool {
return runtime.GOOS == "windows" || runtime.GOOS == "js"
}
// mockEndpointManager implements device.EndpointManager for testing // mockEndpointManager implements device.EndpointManager for testing
type mockEndpointManager struct { type mockEndpointManager struct {
endpoints map[netip.Addr]net.Conn endpoints map[netip.Addr]net.Conn
@@ -176,6 +181,10 @@ func TestBindListener_Close(t *testing.T) {
} }
func TestManager_BindMode(t *testing.T) { func TestManager_BindMode(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager() mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
@@ -217,6 +226,10 @@ func TestManager_BindMode(t *testing.T) {
} }
func TestManager_BindMode_MultiplePeers(t *testing.T) { func TestManager_BindMode_MultiplePeers(t *testing.T) {
if !isBindListenerPlatform() {
t.Skip("BindListener only used on Windows/JS platforms")
}
mockEndpointMgr := newMockEndpointManager() mockEndpointMgr := newMockEndpointManager()
mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr} mockIface := &MockWGIfaceBind{endpointMgr: mockEndpointMgr}
View File
@@ -4,12 +4,14 @@ import (
"errors" "errors"
"net" "net"
"net/netip" "net/netip"
"runtime"
"sync" "sync"
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.zx2c4.com/wireguard/wgctrl/wgtypes" "golang.zx2c4.com/wireguard/wgctrl/wgtypes"
"github.com/netbirdio/netbird/client/iface/netstack"
"github.com/netbirdio/netbird/client/iface/wgaddr" "github.com/netbirdio/netbird/client/iface/wgaddr"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
peerid "github.com/netbirdio/netbird/client/internal/peer/id" peerid "github.com/netbirdio/netbird/client/internal/peer/id"
@@ -73,6 +75,16 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
return NewUDPListener(m.wgIface, peerCfg) return NewUDPListener(m.wgIface, peerCfg)
} }
// BindListener is used on Windows, JS, and netstack platforms:
// - JS: Cannot listen to UDP sockets
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
// gateway points to, preventing them from reaching the loopback interface.
// - Netstack: Allows multiple instances on the same host without port conflicts.
// BindListener bypasses these issues by passing data directly through the bind.
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
return NewUDPListener(m.wgIface, peerCfg)
}
provider, ok := m.wgIface.(bindProvider) provider, ok := m.wgIface.(bindProvider)
if !ok { if !ok {
return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider") return nil, errors.New("interface claims userspace bind but doesn't implement bindProvider")
View File
@@ -6,6 +6,7 @@ import (
"time" "time"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/internal/lazyconn" "github.com/netbirdio/netbird/client/internal/lazyconn"
"github.com/netbirdio/netbird/client/internal/lazyconn/activity" "github.com/netbirdio/netbird/client/internal/lazyconn/activity"
@@ -90,8 +91,8 @@ func (m *Manager) UpdateRouteHAMap(haMap route.HAMap) {
m.routesMu.Lock() m.routesMu.Lock()
defer m.routesMu.Unlock() defer m.routesMu.Unlock()
clear(m.peerToHAGroups) maps.Clear(m.peerToHAGroups)
clear(m.haGroupToPeers) maps.Clear(m.haGroupToPeers)
for haUniqueID, routes := range haMap { for haUniqueID, routes := range haMap {
var peers []string var peers []string
View File
@@ -3,6 +3,8 @@ package store
import ( import (
"sync" "sync"
"golang.org/x/exp/maps"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/netbirdio/netbird/client/internal/netflow/types" "github.com/netbirdio/netbird/client/internal/netflow/types"
@@ -28,7 +30,7 @@ func (m *Memory) StoreEvent(event *types.Event) {
func (m *Memory) Close() { func (m *Memory) Close() {
m.mux.Lock() m.mux.Lock()
defer m.mux.Unlock() defer m.mux.Unlock()
clear(m.events) maps.Clear(m.events)
} }
func (m *Memory) GetEvents() []*types.Event { func (m *Memory) GetEvents() []*types.Event {
View File
@@ -7,8 +7,7 @@ import (
) )
const ( const (
EnvKeyNBForceRelay = "NB_FORCE_RELAY" EnvKeyNBForceRelay = "NB_FORCE_RELAY"
EnvKeyNBHomeRelayServers = "NB_HOME_RELAY_SERVERS"
) )
func IsForceRelayed() bool { func IsForceRelayed() bool {
@@ -17,28 +16,3 @@ func IsForceRelayed() bool {
} }
return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true") return strings.EqualFold(os.Getenv(EnvKeyNBForceRelay), "true")
} }
// OverrideRelayURLs returns the relay server URL list set in
// NB_HOME_RELAY_SERVERS (comma-separated) and a boolean indicating whether
// the override is active. When the env var is unset, the boolean is false
// and the caller should keep the list received from the management server.
// Intended for lab/debug scenarios where a peer must pin to a specific home
// relay regardless of what management offers.
func OverrideRelayURLs() ([]string, bool) {
raw := os.Getenv(EnvKeyNBHomeRelayServers)
if raw == "" {
return nil, false
}
parts := strings.Split(raw, ",")
urls := make([]string, 0, len(parts))
for _, p := range parts {
p = strings.TrimSpace(p)
if p != "" {
urls = append(urls, p)
}
}
if len(urls) == 0 {
return nil, false
}
return urls, true
}
View File
@@ -3,7 +3,6 @@ package peer
import ( import (
"context" "context"
"errors" "errors"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -41,10 +40,6 @@ type OfferAnswer struct {
// relay server address // relay server address
RelaySrvAddress string RelaySrvAddress string
// RelaySrvIP is the IP the remote peer is connected to on its
// relay server. Used as a dial target if DNS for RelaySrvAddress
// fails. Zero value if the peer did not advertise an IP.
RelaySrvIP netip.Addr
// SessionID is the unique identifier of the session, used to discard old messages // SessionID is the unique identifier of the session, used to discard old messages
SessionID *ICESessionID SessionID *ICESessionID
} }
@@ -222,9 +217,8 @@ func (h *Handshaker) buildOfferAnswer() OfferAnswer {
answer.SessionID = &sid answer.SessionID = &sid
} }
if addr, ip, err := h.relay.RelayInstanceAddress(); err == nil { if addr, err := h.relay.RelayInstanceAddress(); err == nil {
answer.RelaySrvAddress = addr answer.RelaySrvAddress = addr
answer.RelaySrvIP = ip
} }
return answer return answer
View File
@@ -8,7 +8,6 @@ import (
type mocListener struct { type mocListener struct {
lastState int lastState int
wg sync.WaitGroup wg sync.WaitGroup
peersWg sync.WaitGroup
peers int peers int
} }
@@ -34,7 +33,6 @@ func (l *mocListener) OnAddressChanged(host, addr string) {
} }
func (l *mocListener) OnPeersListChanged(size int) { func (l *mocListener) OnPeersListChanged(size int) {
l.peers = size l.peers = size
l.peersWg.Done()
} }
func (l *mocListener) setWaiter() { func (l *mocListener) setWaiter() {
@@ -45,14 +43,6 @@ func (l *mocListener) wait() {
l.wg.Wait() l.wg.Wait()
} }
func (l *mocListener) setPeersWaiter() {
l.peersWg.Add(1)
}
func (l *mocListener) waitPeers() {
l.peersWg.Wait()
}
func Test_notifier_serverState(t *testing.T) { func Test_notifier_serverState(t *testing.T) {
type scenario struct { type scenario struct {
@@ -82,13 +72,11 @@ func Test_notifier_serverState(t *testing.T) {
func Test_notifier_SetListener(t *testing.T) { func Test_notifier_SetListener(t *testing.T) {
listener := &mocListener{} listener := &mocListener{}
listener.setWaiter() listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier() n := newNotifier()
n.lastNotification = stateConnecting n.lastNotification = stateConnecting
n.setListener(listener) n.setListener(listener)
listener.wait() listener.wait()
listener.waitPeers()
if listener.lastState != n.lastNotification { if listener.lastState != n.lastNotification {
t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification) t.Errorf("invalid state: %d, expected: %d", listener.lastState, n.lastNotification)
} }
@@ -97,14 +85,9 @@ func Test_notifier_SetListener(t *testing.T) {
func Test_notifier_RemoveListener(t *testing.T) { func Test_notifier_RemoveListener(t *testing.T) {
listener := &mocListener{} listener := &mocListener{}
listener.setWaiter() listener.setWaiter()
listener.setPeersWaiter()
n := newNotifier() n := newNotifier()
n.lastNotification = stateConnecting n.lastNotification = stateConnecting
n.setListener(listener) n.setListener(listener)
// setListener replays cached state on a goroutine; wait for both the state
// and peers callbacks to finish so we don't race on listener.peers.
listener.wait()
listener.waitPeers()
n.removeListener() n.removeListener()
n.peerListChanged(1) n.peerListChanged(1)
View File
@@ -54,19 +54,19 @@ func (s *Signaler) signalOfferAnswer(offerAnswer OfferAnswer, remoteKey string,
log.Warnf("failed to get session ID bytes: %v", err) log.Warnf("failed to get session ID bytes: %v", err)
} }
} }
msg, err := signal.MarshalCredential(s.wgPrivateKey, remoteKey, signal.CredentialPayload{ msg, err := signal.MarshalCredential(
Type: bodyType, s.wgPrivateKey,
WgListenPort: offerAnswer.WgListenPort, offerAnswer.WgListenPort,
Credential: &signal.Credential{ remoteKey,
&signal.Credential{
UFrag: offerAnswer.IceCredentials.UFrag, UFrag: offerAnswer.IceCredentials.UFrag,
Pwd: offerAnswer.IceCredentials.Pwd, Pwd: offerAnswer.IceCredentials.Pwd,
}, },
RosenpassPubKey: offerAnswer.RosenpassPubKey, bodyType,
RosenpassAddr: offerAnswer.RosenpassAddr, offerAnswer.RosenpassPubKey,
RelaySrvAddress: offerAnswer.RelaySrvAddress, offerAnswer.RosenpassAddr,
RelaySrvIP: offerAnswer.RelaySrvIP, offerAnswer.RelaySrvAddress,
SessionID: sessionIDBytes, sessionIDBytes)
})
if err != nil { if err != nil {
return err return err
} }
View File
@@ -320,10 +320,10 @@ func (d *Status) RemovePeer(peerPubKey string) error {
// UpdatePeerState updates peer status // UpdatePeerState updates peer status
func (d *Status) UpdatePeerState(receivedState State) error { func (d *Status) UpdatePeerState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -343,29 +343,23 @@ func (d *Status) UpdatePeerState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
// when we close the connection we will not notify the router manager d.notifyPeerListChanged()
notifyRouter := receivedState.ConnStatus == StatusIdle
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
} }
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) // when we close the connection we will not notify the router manager
if receivedState.ConnStatus == StatusIdle {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
} }
return nil return nil
} }
func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error { func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.ResID) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer] peerState, ok := d.peers[peer]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -377,20 +371,17 @@ func (d *Status) AddPeerStateRoute(peer string, route string, resourceId route.R
d.routeIDLookup.AddRemoteRouteID(resourceId, pref) d.routeIDLookup.AddRemoteRouteID(resourceId, pref)
} }
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: reconsider whether this notification makes sense // todo: reconsider whether this notification makes sense
d.notifier.peerListChanged(numPeers) d.notifyPeerListChanged()
return nil return nil
} }
func (d *Status) RemovePeerStateRoute(peer string, route string) error { func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[peer] peerState, ok := d.peers[peer]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -402,11 +393,8 @@ func (d *Status) RemovePeerStateRoute(peer string, route string) error {
d.routeIDLookup.RemoveRemoteRouteID(pref) d.routeIDLookup.RemoveRemoteRouteID(pref)
} }
numPeers := d.numOfPeers()
d.mux.Unlock()
// todo: reconsider whether this notification makes sense // todo: reconsider whether this notification makes sense
d.notifier.peerListChanged(numPeers) d.notifyPeerListChanged()
return nil return nil
} }
@@ -422,10 +410,10 @@ func (d *Status) CheckRoutes(ip netip.Addr) ([]byte, bool) {
func (d *Status) UpdatePeerICEState(receivedState State) error { func (d *Status) UpdatePeerICEState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -443,28 +431,22 @@ func (d *Status) UpdatePeerICEState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) d.notifyPeerListChanged()
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
} }
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
} }
return nil return nil
} }
func (d *Status) UpdatePeerRelayedState(receivedState State) error { func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -479,28 +461,22 @@ func (d *Status) UpdatePeerRelayedState(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) d.notifyPeerListChanged()
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
} }
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
} }
return nil return nil
} }
func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error { func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -514,28 +490,22 @@ func (d *Status) UpdatePeerRelayedStateToDisconnected(receivedState State) error
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) d.notifyPeerListChanged()
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
} }
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
} }
return nil return nil
} }
func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error { func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
peerState, ok := d.peers[receivedState.PubKey] peerState, ok := d.peers[receivedState.PubKey]
if !ok { if !ok {
d.mux.Unlock()
return errors.New("peer doesn't exist") return errors.New("peer doesn't exist")
} }
@@ -552,18 +522,12 @@ func (d *Status) UpdatePeerICEStateToDisconnected(receivedState State) error {
d.peers[receivedState.PubKey] = peerState d.peers[receivedState.PubKey] = peerState
notifyList := hasConnStatusChanged(oldState, receivedState.ConnStatus) if hasConnStatusChanged(oldState, receivedState.ConnStatus) {
notifyRouter := hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) d.notifyPeerListChanged()
routerSnapshot := d.snapshotRouterPeersLocked(receivedState.PubKey, notifyRouter)
numPeers := d.numOfPeers()
d.mux.Unlock()
if notifyList {
d.notifier.peerListChanged(numPeers)
} }
if notifyRouter {
d.dispatchRouterPeers(receivedState.PubKey, routerSnapshot) if hasStatusOrRelayedChange(oldState, receivedState.ConnStatus, oldIsRelayed, receivedState.Relayed) {
d.notifyPeerStateChangeListeners(receivedState.PubKey)
} }
return nil return nil
} }
@@ -630,33 +594,17 @@ func (d *Status) UpdatePeerSSHHostKey(peerPubKey string, sshHostKey []byte) erro
// FinishPeerListModifications invokes the notification // FinishPeerListModifications invokes the notification
func (d *Status) FinishPeerListModifications() { func (d *Status) FinishPeerListModifications() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
if !d.peerListChangedForNotification { if !d.peerListChangedForNotification {
d.mux.Unlock()
return return
} }
d.peerListChangedForNotification = false d.peerListChangedForNotification = false
numPeers := d.numOfPeers() d.notifyPeerListChanged()
// snapshot per-peer router state to deliver after the lock is released
type routerDispatch struct {
peerID string
snapshot map[string]RouterState
}
dispatches := make([]routerDispatch, 0, len(d.peers))
for key := range d.peers { for key := range d.peers {
snapshot := d.snapshotRouterPeersLocked(key, true) d.notifyPeerStateChangeListeners(key)
if snapshot != nil {
dispatches = append(dispatches, routerDispatch{peerID: key, snapshot: snapshot})
}
}
d.mux.Unlock()
d.notifier.peerListChanged(numPeers)
for _, rd := range dispatches {
d.dispatchRouterPeers(rd.peerID, rd.snapshot)
} }
} }
@@ -707,12 +655,10 @@ func (d *Status) GetLocalPeerState() LocalPeerState {
// UpdateLocalPeerState updates local peer status // UpdateLocalPeerState updates local peer status
func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) { func (d *Status) UpdateLocalPeerState(localPeerState LocalPeerState) {
d.mux.Lock() d.mux.Lock()
d.localPeer = localPeerState defer d.mux.Unlock()
fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip) d.localPeer = localPeerState
d.notifyAddressChanged()
} }
// AddLocalPeerStateRoute adds a route to the local peer state // AddLocalPeerStateRoute adds a route to the local peer state
@@ -775,36 +721,30 @@ func (d *Status) CleanLocalPeerStateRoutes() {
// CleanLocalPeerState cleans local peer status // CleanLocalPeerState cleans local peer status
func (d *Status) CleanLocalPeerState() { func (d *Status) CleanLocalPeerState() {
d.mux.Lock() d.mux.Lock()
d.localPeer = LocalPeerState{} defer d.mux.Unlock()
fqdn := d.localPeer.FQDN
ip := d.localPeer.IP
d.mux.Unlock()
d.notifier.localAddressChanged(fqdn, ip) d.localPeer = LocalPeerState{}
d.notifyAddressChanged()
} }
// MarkManagementDisconnected sets ManagementState to disconnected // MarkManagementDisconnected sets ManagementState to disconnected
func (d *Status) MarkManagementDisconnected(err error) { func (d *Status) MarkManagementDisconnected(err error) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = false d.managementState = false
d.managementError = err d.managementError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
// MarkManagementConnected sets ManagementState to connected // MarkManagementConnected sets ManagementState to connected
func (d *Status) MarkManagementConnected() { func (d *Status) MarkManagementConnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.managementState = true d.managementState = true
d.managementError = nil d.managementError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
// UpdateSignalAddress update the address of the signal server // UpdateSignalAddress update the address of the signal server
@@ -838,25 +778,21 @@ func (d *Status) UpdateLazyConnection(enabled bool) {
// MarkSignalDisconnected sets SignalState to disconnected // MarkSignalDisconnected sets SignalState to disconnected
func (d *Status) MarkSignalDisconnected(err error) { func (d *Status) MarkSignalDisconnected(err error) {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = false d.signalState = false
d.signalError = err d.signalError = err
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
// MarkSignalConnected sets SignalState to connected // MarkSignalConnected sets SignalState to connected
func (d *Status) MarkSignalConnected() { func (d *Status) MarkSignalConnected() {
d.mux.Lock() d.mux.Lock()
defer d.mux.Unlock()
defer d.onConnectionChanged()
d.signalState = true d.signalState = true
d.signalError = nil d.signalError = nil
mgm := d.managementState
sig := d.signalState
d.mux.Unlock()
d.notifier.updateServerStates(mgm, sig)
} }
func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) { func (d *Status) UpdateRelayStates(relayResults []relay.ProbeResult) {
@@ -983,7 +919,7 @@ func (d *Status) GetRelayStates() []relay.ProbeResult {
// if the server connection is not established then we will use the general address // if the server connection is not established then we will use the general address
// in case of connection we will use the instance specific address // in case of connection we will use the instance specific address
instanceAddr, _, err := d.relayMgr.RelayInstanceAddress() instanceAddr, err := d.relayMgr.RelayInstanceAddress()
if err != nil { if err != nil {
// TODO add their status // TODO add their status
for _, r := range d.relayMgr.ServerURLs() { for _, r := range d.relayMgr.ServerURLs() {
@@ -1076,17 +1012,18 @@ func (d *Status) RemoveConnectionListener() {
d.notifier.removeListener() d.notifier.removeListener()
} }
// snapshotRouterPeersLocked builds the RouterState map for a peer's subscribers. func (d *Status) onConnectionChanged() {
// Caller MUST hold d.mux. Returns nil when there are no subscribers for peerID d.notifier.updateServerStates(d.managementState, d.signalState)
// or when notify is false. The snapshot is consumed later by dispatchRouterPeers }
// outside the lock so the channel send cannot stall any d.mux holder.
func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[string]RouterState { // notifyPeerStateChangeListeners notifies route manager about the change in peer state
if !notify { func (d *Status) notifyPeerStateChangeListeners(peerID string) {
return nil subs, ok := d.changeNotify[peerID]
} if !ok {
if _, ok := d.changeNotify[peerID]; !ok { return
return nil
} }
// collect the relevant data for router peers
routerPeers := make(map[string]RouterState, len(d.changeNotify)) routerPeers := make(map[string]RouterState, len(d.changeNotify))
for pid := range d.changeNotify { for pid := range d.changeNotify {
s, ok := d.peers[pid] s, ok := d.peers[pid]
@@ -1094,35 +1031,13 @@ func (d *Status) snapshotRouterPeersLocked(peerID string, notify bool) map[strin
log.Warnf("router peer not found in peers list: %s", pid) log.Warnf("router peer not found in peers list: %s", pid)
continue continue
} }
routerPeers[pid] = RouterState{ routerPeers[pid] = RouterState{
Status: s.ConnStatus, Status: s.ConnStatus,
Relayed: s.Relayed, Relayed: s.Relayed,
Latency: s.Latency, Latency: s.Latency,
} }
} }
return routerPeers
}
// dispatchRouterPeers delivers a previously snapshotted router-state map to
// the peer's subscribers. Caller MUST NOT hold d.mux. The method takes a
// fresh, short read of d.changeNotify under the lock to grab subscriber
// channels, then sends outside the lock so a slow consumer cannot block other
// d.mux holders. The send itself stays blocking (only short-circuited by the
// subscriber's context) so peer state transitions are not silently dropped.
func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]RouterState) {
if routerPeers == nil {
return
}
d.mux.Lock()
subsMap, ok := d.changeNotify[peerID]
subs := make([]*StatusChangeSubscription, 0, len(subsMap))
if ok {
for _, sub := range subsMap {
subs = append(subs, sub)
}
}
d.mux.Unlock()
for _, sub := range subs { for _, sub := range subs {
select { select {
@@ -1132,6 +1047,14 @@ func (d *Status) dispatchRouterPeers(peerID string, routerPeers map[string]Route
} }
} }
func (d *Status) notifyPeerListChanged() {
d.notifier.peerListChanged(d.numOfPeers())
}
func (d *Status) notifyAddressChanged() {
d.notifier.localAddressChanged(d.localPeer.FQDN, d.localPeer.IP)
}
func (d *Status) numOfPeers() int { func (d *Status) numOfPeers() int {
return len(d.peers) + len(d.offlinePeers) return len(d.peers) + len(d.offlinePeers)
} }
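Both sides of this hunk wrestle with the same hazard: notifying subscribers while holding d.mux lets one slow consumer stall every other status update. The removed variant snapshots state under the lock and dispatches after releasing it; a minimal sketch of that shape, with illustrative types rather than the Status struct's real fields:

package main

import (
    "fmt"
    "sync"
)

type broadcaster struct {
    mu    sync.Mutex
    peers map[string]string
    subs  []chan map[string]string
}

// broadcast copies the state and subscriber list under the lock, then sends
// with the lock released so a slow subscriber cannot block other mutex holders.
func (b *broadcaster) broadcast() {
    b.mu.Lock()
    snapshot := make(map[string]string, len(b.peers))
    for k, v := range b.peers {
        snapshot[k] = v
    }
    subs := append([]chan map[string]string(nil), b.subs...)
    b.mu.Unlock()

    for _, ch := range subs {
        ch <- snapshot // blocking send, but never under b.mu
    }
}

func main() {
    ch := make(chan map[string]string, 1)
    b := &broadcaster{
        peers: map[string]string{"peer1": "connected"},
        subs:  []chan map[string]string{ch},
    }
    b.broadcast()
    fmt.Println(<-ch) // map[peer1:connected]
}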
View File
@@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"net" "net"
"net/netip"
"sync" "sync"
"sync/atomic" "sync/atomic"
@@ -54,19 +53,15 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
w.relaySupportedOnRemotePeer.Store(true) w.relaySupportedOnRemotePeer.Store(true)
// the relayManager will return with error in case if the connection has lost with relay server // the relayManager will return with error in case if the connection has lost with relay server
currentRelayAddress, _, err := w.relayManager.RelayInstanceAddress() currentRelayAddress, err := w.relayManager.RelayInstanceAddress()
if err != nil { if err != nil {
w.log.Errorf("failed to handle new offer: %s", err) w.log.Errorf("failed to handle new offer: %s", err)
return return
} }
srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress) srv := w.preferredRelayServer(currentRelayAddress, remoteOfferAnswer.RelaySrvAddress)
var serverIP netip.Addr
if srv == remoteOfferAnswer.RelaySrvAddress {
serverIP = remoteOfferAnswer.RelaySrvIP
}
relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key, serverIP) relayedConn, err := w.relayManager.OpenConn(w.peerCtx, srv, w.config.Key)
if err != nil { if err != nil {
if errors.Is(err, relayClient.ErrConnAlreadyExists) { if errors.Is(err, relayClient.ErrConnAlreadyExists) {
w.log.Debugf("handled offer by reusing existing relay connection") w.log.Debugf("handled offer by reusing existing relay connection")
@@ -95,7 +90,7 @@ func (w *WorkerRelay) OnNewOffer(remoteOfferAnswer *OfferAnswer) {
}) })
} }
func (w *WorkerRelay) RelayInstanceAddress() (string, netip.Addr, error) { func (w *WorkerRelay) RelayInstanceAddress() (string, error) {
return w.relayManager.RelayInstanceAddress() return w.relayManager.RelayInstanceAddress()
} }

View File

@@ -89,16 +89,8 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec) return false, fmt.Errorf("unusable default nexthop for %s (no interface)", unspec)
} }
reused := false
if err := r.addScopedDefault(unspec, nexthop); err != nil { if err := r.addScopedDefault(unspec, nexthop); err != nil {
if !errors.Is(err, unix.EEXIST) { return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
return false, fmt.Errorf("add scoped default on %s: %w", nexthop.Intf.Name, err)
}
// macOS installs its own RTF_IFSCOPE defaults for primary service
// selection on multi-NIC setups, so a route on this ifindex can
// already exist before we try. Binding to it via IP[V6]_BOUND_IF
// still produces the scoped lookup we need.
reused = true
} }
af := unix.AF_INET af := unix.AF_INET
@@ -110,11 +102,7 @@ func (r *SysOps) installScopedDefaultFor(unspec netip.Addr) (bool, error) {
if nexthop.IP.IsValid() { if nexthop.IP.IsValid() {
via = nexthop.IP.String() via = nexthop.IP.String()
} }
verb := "installed" log.Infof("installed scoped default route via %s on %s for %s", via, nexthop.Intf.Name, afOf(unspec))
if reused {
verb = "reused existing"
}
log.Infof("%s scoped default route via %s on %s for %s", verb, via, nexthop.Intf.Name, afOf(unspec))
return true, nil return true, nil
} }

View File

@@ -7,6 +7,7 @@ import (
"sync" "sync"
"github.com/hashicorp/go-multierror" "github.com/hashicorp/go-multierror"
"golang.org/x/exp/maps"
"github.com/netbirdio/netbird/client/errors" "github.com/netbirdio/netbird/client/errors"
"github.com/netbirdio/netbird/route" "github.com/netbirdio/netbird/route"
@@ -43,8 +44,8 @@ func (rs *RouteSelector) SelectRoutes(routes []route.NetID, appendRoute bool, al
if rs.selectedRoutes == nil { if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
clear(rs.deselectedRoutes) maps.Clear(rs.deselectedRoutes)
clear(rs.selectedRoutes) maps.Clear(rs.selectedRoutes)
for _, r := range allRoutes { for _, r := range allRoutes {
rs.deselectedRoutes[r] = struct{}{} rs.deselectedRoutes[r] = struct{}{}
} }
@@ -77,8 +78,8 @@ func (rs *RouteSelector) SelectAllRoutes() {
if rs.selectedRoutes == nil { if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
clear(rs.deselectedRoutes) maps.Clear(rs.deselectedRoutes)
clear(rs.selectedRoutes) maps.Clear(rs.selectedRoutes)
} }
// DeselectRoutes removes specific routes from the selection. // DeselectRoutes removes specific routes from the selection.
@@ -115,8 +116,8 @@ func (rs *RouteSelector) DeselectAllRoutes() {
if rs.selectedRoutes == nil { if rs.selectedRoutes == nil {
rs.selectedRoutes = map[route.NetID]struct{}{} rs.selectedRoutes = map[route.NetID]struct{}{}
} }
clear(rs.deselectedRoutes) maps.Clear(rs.deselectedRoutes)
clear(rs.selectedRoutes) maps.Clear(rs.selectedRoutes)
} }
// IsSelected checks if a specific route is selected. // IsSelected checks if a specific route is selected.
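Several hunks above swap the Go 1.21 clear builtin back to maps.Clear from golang.org/x/exp; for maps the two are behaviorally identical, as a quick sketch shows (assuming the x/exp module is available):

package main

import (
    "fmt"

    "golang.org/x/exp/maps"
)

func main() {
    a := map[string]int{"x": 1}
    b := map[string]int{"x": 1}

    clear(a)      // Go 1.21+ builtin: deletes all entries
    maps.Clear(b) // pre-1.21 equivalent from x/exp, now deprecated in favor of clear

    fmt.Println(len(a), len(b)) // 0 0
}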
View File
@@ -2,358 +2,217 @@
package sleep package sleep
/*
#cgo LDFLAGS: -framework IOKit -framework CoreFoundation
#include <IOKit/pwr_mgt/IOPMLib.h>
#include <IOKit/IOMessage.h>
#include <CoreFoundation/CoreFoundation.h>
extern void sleepCallbackBridge();
extern void poweredOnCallbackBridge();
extern void suspendedCallbackBridge();
extern void resumedCallbackBridge();
// C global variables for IOKit state
static IONotificationPortRef g_notifyPortRef = NULL;
static io_object_t g_notifierObject = 0;
static io_object_t g_generalInterestNotifier = 0;
static io_connect_t g_rootPort = 0;
static CFRunLoopRef g_runLoop = NULL;
static void sleepCallback(void* refCon, io_service_t service, natural_t messageType, void* messageArgument) {
switch (messageType) {
case kIOMessageSystemWillSleep:
sleepCallbackBridge();
IOAllowPowerChange(g_rootPort, (long)messageArgument);
break;
case kIOMessageSystemHasPoweredOn:
poweredOnCallbackBridge();
break;
case kIOMessageServiceIsSuspended:
suspendedCallbackBridge();
break;
case kIOMessageServiceIsResumed:
resumedCallbackBridge();
break;
default:
break;
}
}
static void registerNotifications() {
g_rootPort = IORegisterForSystemPower(
NULL,
&g_notifyPortRef,
(IOServiceInterestCallback)sleepCallback,
&g_notifierObject
);
if (g_rootPort == 0) {
return;
}
CFRunLoopAddSource(CFRunLoopGetCurrent(),
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
g_runLoop = CFRunLoopGetCurrent();
CFRunLoopRun();
}
static void unregisterNotifications() {
CFRunLoopRemoveSource(g_runLoop,
IONotificationPortGetRunLoopSource(g_notifyPortRef),
kCFRunLoopCommonModes);
IODeregisterForSystemPower(&g_notifierObject);
IOServiceClose(g_rootPort);
IONotificationPortDestroy(g_notifyPortRef);
CFRunLoopStop(g_runLoop);
g_notifyPortRef = NULL;
g_notifierObject = 0;
g_rootPort = 0;
g_runLoop = NULL;
}
*/
import "C"
import ( import (
"context"
"fmt" "fmt"
"runtime" "runtime"
"sync" "sync"
"time" "time"
"unsafe"
"github.com/ebitengine/purego"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
) )
// IOKit message types from IOKit/IOMessage.h.
const (
kIOMessageCanSystemSleep uintptr = 0xe0000270
kIOMessageSystemWillSleep uintptr = 0xe0000280
kIOMessageSystemHasPoweredOn uintptr = 0xe0000300
)
var ( var (
ioKit iokitFuncs
cf cfFuncs
cfCommonModes uintptr
libInitOnce sync.Once
libInitErr error
// callbackThunk is the single C-callable trampoline registered with IOKit.
callbackThunk uintptr
serviceRegistry = make(map[*Detector]struct{}) serviceRegistry = make(map[*Detector]struct{})
serviceRegistryMu sync.Mutex serviceRegistryMu sync.Mutex
session *runLoopSession
// lifecycleMu serializes Register/Deregister so a new registration can't
// start a second runloop while a previous teardown is still pending.
lifecycleMu sync.Mutex
) )
// iokitFuncs holds IOKit symbols resolved once at init. //export sleepCallbackBridge
type iokitFuncs struct { func sleepCallbackBridge() {
IORegisterForSystemPower func(refcon uintptr, portRef *uintptr, callback uintptr, notifier *uintptr) uintptr log.Info("sleepCallbackBridge event triggered")
IODeregisterForSystemPower func(notifier *uintptr) int32
IOAllowPowerChange func(kernelPort uintptr, notificationID uintptr) int32
IOServiceClose func(connect uintptr) int32
IONotificationPortGetRunLoopSource func(port uintptr) uintptr
IONotificationPortDestroy func(port uintptr)
}
// cfFuncs holds CoreFoundation symbols resolved once at init.
type cfFuncs struct {
CFRunLoopGetCurrent func() uintptr
CFRunLoopRun func()
CFRunLoopStop func(rl uintptr)
CFRunLoopAddSource func(rl, source, mode uintptr)
CFRunLoopRemoveSource func(rl, source, mode uintptr)
}
// runLoopSession bundles the handles owned by one CFRunLoop lifetime. A nil
// session means no runloop is active and the next Register must start one.
type runLoopSession struct {
rl uintptr
port uintptr
notifier uintptr
rp uintptr
}
// detectorSnapshot pins a detector's callback and done channel so dispatch
// runs with values valid at snapshot time, even if a concurrent
// Deregister/Register rewrites the detector's fields.
type detectorSnapshot struct {
detector *Detector
callback func(event EventType)
done <-chan struct{}
}
// Detector delivers sleep and wake events to a registered callback.
type Detector struct {
callback func(event EventType)
done chan struct{}
}
// Register installs callback for power events. The first registration starts
// the CFRunLoop on a dedicated OS-locked thread and blocks until IOKit
// registration succeeds or fails; subsequent registrations just add to the
// dispatch set.
func (d *Detector) Register(callback func(event EventType)) error {
lifecycleMu.Lock()
defer lifecycleMu.Unlock()
serviceRegistryMu.Lock() serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeSleep)
}
}
//export resumedCallbackBridge
func resumedCallbackBridge() {
log.Info("resumedCallbackBridge event triggered")
}
//export suspendedCallbackBridge
func suspendedCallbackBridge() {
log.Info("suspendedCallbackBridge event triggered")
}
//export poweredOnCallbackBridge
func poweredOnCallbackBridge() {
log.Info("poweredOnCallbackBridge event triggered")
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
for svc := range serviceRegistry {
svc.triggerCallback(EventTypeWakeUp)
}
}
type Detector struct {
callback func(event EventType)
ctx context.Context
cancel context.CancelFunc
}
func NewDetector() (*Detector, error) {
return &Detector{}, nil
}
func (d *Detector) Register(callback func(event EventType)) error {
serviceRegistryMu.Lock()
defer serviceRegistryMu.Unlock()
if _, exists := serviceRegistry[d]; exists { if _, exists := serviceRegistry[d]; exists {
serviceRegistryMu.Unlock()
return fmt.Errorf("detector service already registered") return fmt.Errorf("detector service already registered")
} }
d.callback = callback
d.done = make(chan struct{})
serviceRegistry[d] = struct{}{}
needSetup := session == nil
serviceRegistryMu.Unlock()
if !needSetup { d.callback = callback
d.ctx, d.cancel = context.WithCancel(context.Background())
if len(serviceRegistry) > 0 {
serviceRegistry[d] = struct{}{}
return nil return nil
} }
errCh := make(chan error, 1) serviceRegistry[d] = struct{}{}
go runRunLoop(errCh)
if err := <-errCh; err != nil { // CFRunLoop must run on a single fixed OS thread
serviceRegistryMu.Lock() go func() {
delete(serviceRegistry, d) runtime.LockOSThread()
close(d.done) defer runtime.UnlockOSThread()
d.done = nil
serviceRegistryMu.Unlock() C.registerNotifications()
return err }()
}
log.Info("sleep detection service started on macOS") log.Info("sleep detection service started on macOS")
return nil return nil
} }
// Deregister removes the detector. When the last detector leaves, IOKit // Deregister removes the detector. When the last detector is removed, IOKit registration is torn down
// notifications are torn down and the runloop is stopped. // and the runloop is stopped and cleaned up.
func (d *Detector) Deregister() error { func (d *Detector) Deregister() error {
lifecycleMu.Lock()
defer lifecycleMu.Unlock()
serviceRegistryMu.Lock() serviceRegistryMu.Lock()
if _, exists := serviceRegistry[d]; !exists { defer serviceRegistryMu.Unlock()
serviceRegistryMu.Unlock() _, exists := serviceRegistry[d]
if !exists {
return nil return nil
} }
close(d.done)
// cancel and remove this detector
d.cancel()
delete(serviceRegistry, d) delete(serviceRegistry, d)
// If other Detectors still exist, leave IOKit running
if len(serviceRegistry) > 0 { if len(serviceRegistry) > 0 {
serviceRegistryMu.Unlock()
return nil return nil
} }
sess := session
serviceRegistryMu.Unlock()
log.Info("sleep detection service stopping (deregister)") log.Info("sleep detection service stopping (deregister)")
if sess == nil { // Deregister IOKit notifications, stop runloop, and free resources
return nil C.unregisterNotifications()
}
if sess.rl != 0 && sess.port != 0 {
source := ioKit.IONotificationPortGetRunLoopSource(sess.port)
cf.CFRunLoopRemoveSource(sess.rl, source, cfCommonModes)
}
if sess.notifier != 0 {
n := sess.notifier
ioKit.IODeregisterForSystemPower(&n)
}
// Clear session only after IODeregisterForSystemPower returns so any
// in-flight powerCallback can still look up session.rp to ack sleep.
serviceRegistryMu.Lock()
session = nil
serviceRegistryMu.Unlock()
if sess.rp != 0 {
ioKit.IOServiceClose(sess.rp)
}
if sess.port != 0 {
ioKit.IONotificationPortDestroy(sess.port)
}
if sess.rl != 0 {
cf.CFRunLoopStop(sess.rl)
}
return nil return nil
} }
func (d *Detector) triggerCallback(event EventType, cb func(event EventType), done <-chan struct{}) { func (d *Detector) triggerCallback(event EventType) {
if cb == nil || done == nil {
return
}
select {
case <-done:
return
default:
}
doneChan := make(chan struct{}) doneChan := make(chan struct{})
timeout := time.NewTimer(500 * time.Millisecond) timeout := time.NewTimer(500 * time.Millisecond)
defer timeout.Stop() defer timeout.Stop()
go func() { cb := d.callback
defer close(doneChan) go func(callback func(event EventType)) {
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep callback: %v", r)
}
}()
log.Info("sleep detection event fired") log.Info("sleep detection event fired")
cb(event) callback(event)
}() close(doneChan)
}(cb)
select { select {
case <-doneChan: case <-doneChan:
case <-done: case <-d.ctx.Done():
case <-timeout.C: case <-timeout.C:
log.Warn("sleep callback timed out") log.Warnf("sleep callback timed out")
} }
} }
// NewDetector initializes IOKit/CoreFoundation bindings and returns a Detector.
func NewDetector() (*Detector, error) {
if err := initLibs(); err != nil {
return nil, err
}
return &Detector{}, nil
}
func initLibs() error {
libInitOnce.Do(func() {
iokit, err := purego.Dlopen("/System/Library/Frameworks/IOKit.framework/IOKit", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen IOKit: %w", err)
return
}
cfLib, err := purego.Dlopen("/System/Library/Frameworks/CoreFoundation.framework/CoreFoundation", purego.RTLD_NOW|purego.RTLD_GLOBAL)
if err != nil {
libInitErr = fmt.Errorf("dlopen CoreFoundation: %w", err)
return
}
purego.RegisterLibFunc(&ioKit.IORegisterForSystemPower, iokit, "IORegisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IODeregisterForSystemPower, iokit, "IODeregisterForSystemPower")
purego.RegisterLibFunc(&ioKit.IOAllowPowerChange, iokit, "IOAllowPowerChange")
purego.RegisterLibFunc(&ioKit.IOServiceClose, iokit, "IOServiceClose")
purego.RegisterLibFunc(&ioKit.IONotificationPortGetRunLoopSource, iokit, "IONotificationPortGetRunLoopSource")
purego.RegisterLibFunc(&ioKit.IONotificationPortDestroy, iokit, "IONotificationPortDestroy")
purego.RegisterLibFunc(&cf.CFRunLoopGetCurrent, cfLib, "CFRunLoopGetCurrent")
purego.RegisterLibFunc(&cf.CFRunLoopRun, cfLib, "CFRunLoopRun")
purego.RegisterLibFunc(&cf.CFRunLoopStop, cfLib, "CFRunLoopStop")
purego.RegisterLibFunc(&cf.CFRunLoopAddSource, cfLib, "CFRunLoopAddSource")
purego.RegisterLibFunc(&cf.CFRunLoopRemoveSource, cfLib, "CFRunLoopRemoveSource")
modeAddr, err := purego.Dlsym(cfLib, "kCFRunLoopCommonModes")
if err != nil {
libInitErr = fmt.Errorf("dlsym kCFRunLoopCommonModes: %w", err)
return
}
// Launder the uintptr-to-pointer conversion through a Go variable so
// go vet's unsafeptr analyzer doesn't flag a system-library global.
cfCommonModes = **(**uintptr)(unsafe.Pointer(&modeAddr))
// NewCallback slots are a finite, non-reclaimable resource, so register
// a single thunk that dispatches to the current Detector set.
callbackThunk = purego.NewCallback(powerCallback)
})
return libInitErr
}
// powerCallback is the IOServiceInterestCallback trampoline, invoked on the
// runloop thread. A Go panic crossing the purego boundary has undefined
// behavior, so contain it here.
func powerCallback(refcon, service, messageType, messageArgument uintptr) uintptr {
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep powerCallback: %v", r)
}
}()
switch messageType {
case kIOMessageCanSystemSleep:
// Not acknowledging forces a 30s IOKit timeout before idle sleep.
allowPowerChange(messageArgument)
case kIOMessageSystemWillSleep:
dispatchEvent(EventTypeSleep)
allowPowerChange(messageArgument)
case kIOMessageSystemHasPoweredOn:
dispatchEvent(EventTypeWakeUp)
}
return 0
}
func allowPowerChange(messageArgument uintptr) {
serviceRegistryMu.Lock()
var port uintptr
if session != nil {
port = session.rp
}
serviceRegistryMu.Unlock()
if port != 0 {
ioKit.IOAllowPowerChange(port, messageArgument)
}
}
func dispatchEvent(event EventType) {
serviceRegistryMu.Lock()
snaps := make([]detectorSnapshot, 0, len(serviceRegistry))
for d := range serviceRegistry {
snaps = append(snaps, detectorSnapshot{
detector: d,
callback: d.callback,
done: d.done,
})
}
serviceRegistryMu.Unlock()
for _, s := range snaps {
s.detector.triggerCallback(event, s.callback, s.done)
}
}
// runRunLoop owns the OS-locked thread that CFRunLoop is pinned to. Setup
// result is reported on errCh so Register can surface failures synchronously.
func runRunLoop(errCh chan<- error) {
runtime.LockOSThread()
defer runtime.UnlockOSThread()
sess, err := setupSession()
if err == nil {
serviceRegistryMu.Lock()
session = sess
serviceRegistryMu.Unlock()
}
errCh <- err
if err != nil {
return
}
defer func() {
if r := recover(); r != nil {
log.Errorf("panic in sleep runloop: %v", r)
}
}()
cf.CFRunLoopRun()
}
// setupSession performs the IOKit registration on the current thread. Panics
// are converted to errors so runRunLoop never leaves errCh unsent.
func setupSession() (s *runLoopSession, err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("panic during runloop setup: %v", r)
}
}()
var portRef, notifier uintptr
rp := ioKit.IORegisterForSystemPower(0, &portRef, callbackThunk, &notifier)
if rp == 0 {
return nil, fmt.Errorf("IORegisterForSystemPower returned zero")
}
rl := cf.CFRunLoopGetCurrent()
source := ioKit.IONotificationPortGetRunLoopSource(portRef)
cf.CFRunLoopAddSource(rl, source, cfCommonModes)
return &runLoopSession{rl: rl, port: portRef, notifier: notifier, rp: rp}, nil
}
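The removed runRunLoop/setupSession pair above pins CFRunLoopRun to one OS thread while still letting Register fail synchronously. The core of that pattern, boiled down to a self-contained sketch (setup and loop here are stand-ins, not NetBird or IOKit APIs):

    package main

    import (
    	"fmt"
    	"runtime"
    )

    // setup and loop are stand-ins for the IOKit registration and CFRunLoopRun.
    func setup() error { return nil }
    func loop()        {}

    // start pins the event loop to a single OS thread (CFRunLoop requires it)
    // while still reporting setup failures synchronously to the caller.
    func start() error {
    	errCh := make(chan error, 1) // buffered: the goroutine never blocks on send
    	go func() {
    		runtime.LockOSThread()
    		defer runtime.UnlockOSThread()

    		err := setup()
    		errCh <- err // report the result before entering the loop
    		if err != nil {
    			return
    		}
    		loop() // in the real code this blocks for the loop's lifetime
    	}()
    	return <-errCh
    }

    func main() {
    	if err := start(); err != nil {
    		fmt.Println("start failed:", err)
    		return
    	}
    	fmt.Println("event loop running on a locked OS thread")
    }

The buffered channel is the load-bearing detail: the loop goroutine can hand back its setup result and keep running without ever blocking on the caller.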

View File

@@ -6,6 +6,7 @@ import (
 	"fmt"
 	"net"
 	"runtime"
+	"time"

 	log "github.com/sirupsen/logrus"

@@ -27,10 +28,6 @@ func NewWGIfaceMonitor() *WGIfaceMonitor {
 // Start begins monitoring the WireGuard interface.
 // It relies on the provided context cancellation to stop.
-//
-// On Linux the watcher is event-driven (RTNLGRP_LINK netlink subscription)
-// to avoid the allocation churn of repeatedly dumping the kernel link
-// table; on other platforms it falls back to a low-frequency poll.
 func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRestart bool, err error) {
 	defer close(m.done)

@@ -59,7 +56,31 @@ func (m *WGIfaceMonitor) Start(ctx context.Context, ifaceName string) (shouldRes
 	log.Infof("Interface monitor: watching %s (index: %d)", ifaceName, expectedIndex)

-	return watchInterface(ctx, ifaceName, expectedIndex)
+	ticker := time.NewTicker(2 * time.Second)
+	defer ticker.Stop()
+
+	for {
+		select {
+		case <-ctx.Done():
+			log.Infof("Interface monitor: stopped for %s", ifaceName)
+			return false, fmt.Errorf("wg interface monitor stopped: %v", ctx.Err())
+		case <-ticker.C:
+			currentIndex, err := getInterfaceIndex(ifaceName)
+			if err != nil {
+				// Interface was deleted
+				log.Infof("Interface monitor: %s deleted", ifaceName)
+				return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
+			}
+
+			// Check if interface index changed (interface was recreated)
+			if currentIndex != expectedIndex {
+				log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
+					ifaceName, expectedIndex, currentIndex)
+				return true, nil
+			}
+		}
+	}
 }

 // getInterfaceIndex returns the index of a network interface by name.
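getInterfaceIndex itself falls outside this hunk; a plausible implementation, assumed here for illustration rather than copied from the repository, is a thin wrapper over net.InterfaceByName:

    package main

    import (
    	"fmt"
    	"net"
    )

    // getInterfaceIndex returns the kernel index of a network interface by name.
    // On Linux, net.InterfaceByName dumps the kernel link table via netlink,
    // which is what makes a 2 s polling loop comparatively expensive (see the
    // deleted watcher's comments below).
    func getInterfaceIndex(name string) (int, error) {
    	iface, err := net.InterfaceByName(name)
    	if err != nil {
    		return 0, fmt.Errorf("lookup interface %s: %w", name, err)
    	}
    	return iface.Index, nil
    }

    func main() {
    	idx, err := getInterfaceIndex("lo")
    	if err != nil {
    		fmt.Println(err)
    		return
    	}
    	fmt.Println("lo index:", idx)
    }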

View File

@@ -1,134 +0,0 @@
//go:build linux
package internal
import (
"context"
"fmt"
"syscall"
log "github.com/sirupsen/logrus"
"github.com/vishvananda/netlink"
)
// watchInterface uses an RTNLGRP_LINK netlink subscription to detect
// deletion or recreation of the WireGuard interface.
//
// The previous implementation polled net.InterfaceByName every 2 s, which
// on Linux issues syscall.NetlinkRIB(RTM_GETLINK, ...) and dumps the
// entire kernel link table on every call. On hosts with many veth
// interfaces (containers, bridges) the resulting allocation churn was on
// the order of ~1 GB/day from this single ticker, which on small ARM
// hosts manifested as a slow RSS climb (see netbirdio/netbird#3678).
//
// The event-driven version below allocates only when the kernel actually
// publishes a link event for the tracked interface — typically zero
// allocations between events.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
done := make(chan struct{})
defer close(done)
// Buffer the channel to absorb event bursts (e.g. when many veth
// pairs are created/destroyed at once by container runtimes).
linkChan := make(chan netlink.LinkUpdate, 32)
if err := netlink.LinkSubscribe(linkChan, done); err != nil {
// Return shouldRestart=true so the engine recovers monitoring
// via triggerClientRestart instead of silently losing it for
// the rest of the process lifetime.
return true, fmt.Errorf("subscribe to link updates: %w", err)
}
// Race window: the interface could have been deleted (or recreated)
// between the initial getInterfaceIndex() in Start and LinkSubscribe
// completing its handshake with the kernel. Re-check explicitly so we
// do not block forever waiting for an event that already fired.
if currentIndex, err := getInterfaceIndex(ifaceName); err != nil {
log.Infof("Interface monitor: %s deleted before subscription completed", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
} else if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d) before subscription completed",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case update, ok := <-linkChan:
if !ok {
// The vishvananda/netlink subscription goroutine closes
// the channel on receive errors. Signal the engine to
// restart so monitoring is re-established instead of
// silently ending.
log.Warnf("Interface monitor: link subscription channel closed unexpectedly for %s", ifaceName)
return true, fmt.Errorf("link subscription channel closed unexpectedly")
}
if restart, err := inspectLinkEvent(update, ifaceName, expectedIndex); restart {
return true, err
}
}
}
}
// inspectLinkEvent classifies a single netlink link update against the
// tracked WireGuard interface. It returns (true, err) when the engine
// should restart monitoring; (false, nil) means the event is unrelated
// and the caller should keep waiting.
//
// The error component, when non-nil, describes the kernel-side reason
// (deletion or rename); the recreation case returns (true, nil) since
// no error condition is reported.
func inspectLinkEvent(update netlink.LinkUpdate, ifaceName string, expectedIndex int) (bool, error) {
eventIndex := int(update.Index)
eventName := ""
if attrs := update.Attrs(); attrs != nil {
eventName = attrs.Name
}
switch update.Header.Type {
case syscall.RTM_DELLINK:
return inspectDelLink(eventIndex, ifaceName, expectedIndex)
case syscall.RTM_NEWLINK:
return inspectNewLink(eventIndex, eventName, ifaceName, expectedIndex)
}
return false, nil
}
// inspectDelLink reports a restart when an RTM_DELLINK arrives for the
// tracked interface index.
func inspectDelLink(eventIndex int, ifaceName string, expectedIndex int) (bool, error) {
if eventIndex != expectedIndex {
return false, nil
}
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted", ifaceName)
}
// inspectNewLink reports a restart when an RTM_NEWLINK either:
//
// 1. Introduces a link with our name at a different index (recreation
// after a delete), or
//
// 2. Reports a link still at our index but with a different name
// (in-place rename). The previous polling implementation caught
// this implicitly because net.InterfaceByName(ifaceName) would
// start failing; the event-driven version has to test it.
//
// Same name + same index is just a flag/state change on the existing
// interface and is ignored.
func inspectNewLink(eventIndex int, eventName, ifaceName string, expectedIndex int) (bool, error) {
if eventName == ifaceName && eventIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, eventIndex)
return true, nil
}
if eventIndex == expectedIndex && eventName != "" && eventName != ifaceName {
log.Infof("Interface monitor: %s renamed to %s (index %d), restarting engine",
ifaceName, eventName, expectedIndex)
return true, fmt.Errorf("interface %s renamed to %s", ifaceName, eventName)
}
return false, nil
}
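For context, the subscription primitive the deleted watcher builds on can be exercised standalone. A minimal sketch using github.com/vishvananda/netlink, with event handling trimmed down to prints:

    //go:build linux

    package main

    import (
    	"fmt"
    	"syscall"

    	"github.com/vishvananda/netlink"
    )

    func main() {
    	done := make(chan struct{})
    	defer close(done)

    	// Buffered, as in the watcher above, to absorb bursts of link events.
    	updates := make(chan netlink.LinkUpdate, 32)
    	if err := netlink.LinkSubscribe(updates, done); err != nil {
    		fmt.Println("subscribe:", err)
    		return
    	}

    	for update := range updates {
    		name := ""
    		if attrs := update.Attrs(); attrs != nil {
    			name = attrs.Name
    		}
    		switch update.Header.Type {
    		case syscall.RTM_NEWLINK:
    			fmt.Printf("link created/changed: %s (index %d)\n", name, update.Index)
    		case syscall.RTM_DELLINK:
    			fmt.Printf("link deleted: %s (index %d)\n", name, update.Index)
    		}
    	}
    }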

View File

@@ -1,56 +0,0 @@
//go:build !linux
package internal
import (
"context"
"fmt"
"time"
log "github.com/sirupsen/logrus"
)
// watchInterface polls net.InterfaceByName at a fixed interval to detect
// deletion or recreation of the WireGuard interface.
//
// This is the fallback used on non-Linux desktop and server platforms
// (darwin, windows, freebsd). It is also compiled on android and ios so
// the package builds on every supported GOOS, but it is never reached
// at runtime there because Start() in wg_iface_monitor.go exits early
// on mobile platforms.
//
// The Linux build (see wg_iface_monitor_linux.go) uses an event-driven
// RTNLGRP_LINK netlink subscription instead, because on Linux
// net.InterfaceByName issues syscall.NetlinkRIB(RTM_GETLINK, ...) which
// dumps the entire kernel link table on every call and produces
// significant allocation churn (netbirdio/netbird#3678).
//
// Windows is also reported in #3678 as affected by RSS climb. A future
// follow-up could implement an event-driven watcher there using
// NotifyIpInterfaceChange from iphlpapi.
func watchInterface(ctx context.Context, ifaceName string, expectedIndex int) (bool, error) {
ticker := time.NewTicker(2 * time.Second)
defer ticker.Stop()
for {
select {
case <-ctx.Done():
log.Infof("Interface monitor: stopped for %s", ifaceName)
return false, fmt.Errorf("wg interface monitor stopped: %w", ctx.Err())
case <-ticker.C:
currentIndex, err := getInterfaceIndex(ifaceName)
if err != nil {
// Interface was deleted
log.Infof("Interface monitor: %s deleted", ifaceName)
return true, fmt.Errorf("interface %s deleted: %w", ifaceName, err)
}
// Check if interface index changed (interface was recreated)
if currentIndex != expectedIndex {
log.Infof("Interface monitor: %s recreated (index changed from %d to %d), restarting engine",
ifaceName, expectedIndex, currentIndex)
return true, nil
}
}
}
}

View File

@@ -13,25 +13,15 @@
     <MajorUpgrade AllowSameVersionUpgrades='yes' DowngradeErrorMessage="A newer version of [ProductName] is already installed. Setup will now exit."/>

-    <!-- Autostart: enabled by default, disable with AUTOSTART=0 on the msiexec command line -->
-    <Property Id="AUTOSTART" Value="1" />
-
     <StandardDirectory Id="ProgramFiles64Folder">
       <Directory Id="NetbirdInstallDir" Name="Netbird">
         <Component Id="NetbirdFiles" Guid="db3165de-cc6e-4922-8396-9d892950e23e" Bitness="always64">
           <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird.exe" KeyPath="yes" />
           <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\netbird-ui.exe">
-            <Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
-              <ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
-              <ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
-            </Shortcut>
-            <Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon">
-              <ShortcutProperty Key="System.AppUserModel.ID" Value="NetBird" />
-              <ShortcutProperty Key="System.AppUserModel.ToastActivatorCLSID" Value="{0E1B4DE7-E148-432B-9814-544F941826EC}" />
-            </Shortcut>
+            <Shortcut Id="NetbirdDesktopShortcut" Directory="DesktopFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
+            <Shortcut Id="NetbirdStartMenuShortcut" Directory="StartMenuFolder" Name="NetBird" WorkingDirectory="NetbirdInstallDir" Icon="NetbirdIcon" />
           </File>
           <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\wintun.dll" />
-          <File Id="NetbirdToastIcon" Name="netbird.png" Source=".\client\ui\assets\netbird.png" />
           <?if $(var.ArchSuffix) = "amd64" ?>
           <File ProcessorArchitecture="$(var.ProcessorArchitecture)" Source=".\dist\netbird_windows_$(var.ArchSuffix)\opengl32.dll" />
           <?endif ?>

@@ -56,30 +46,8 @@
       </Directory>
     </StandardDirectory>

-    <!-- Per-user component: HKCU keypath (auto GUID via "*"), separate from
-         the per-machine NetbirdFiles component to satisfy ICE57. -->
-    <StandardDirectory Id="ProgramMenuFolder">
-      <Component Id="NetbirdAumidRegistry" Guid="*">
-        <RegistryKey Root="HKCU" Key="Software\Classes\AppUserModelId\NetBird" ForceDeleteOnUninstall="yes">
-          <RegistryValue Name="InstalledByMSI" Type="integer" Value="1" KeyPath="yes" />
-        </RegistryKey>
-      </Component>
-    </StandardDirectory>
-
-    <StandardDirectory Id="CommonAppDataFolder">
-      <Directory Id="NetbirdAutoStartDir" Name="Netbird">
-        <Component Id="NetbirdAutoStart" Guid="b199eaca-b0dd-4032-af19-679cfad48eb3" Bitness="always64" Condition='AUTOSTART = "1"'>
-          <RegistryValue Root="HKLM" Key="Software\Microsoft\Windows\CurrentVersion\Run"
-                         Name="Netbird" Value="&quot;[NetbirdInstallDir]netbird-ui.exe&quot;"
-                         Type="string" KeyPath="yes" />
-        </Component>
-      </Directory>
-    </StandardDirectory>
-
     <ComponentGroup Id="NetbirdFilesComponent">
       <ComponentRef Id="NetbirdFiles" />
-      <ComponentRef Id="NetbirdAumidRegistry" />
-      <ComponentRef Id="NetbirdAutoStart" />
     </ComponentGroup>

     <util:CloseApplication Id="CloseNetBird" CloseMessage="no" Target="netbird.exe" RebootPrompt="no" />
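The removed AUTOSTART property was designed to be set on the install command line, per its own comment. With that component in place, a quiet install with autostart disabled would have looked roughly like this (illustrative invocation, not taken from the repository's docs):

    msiexec /i netbird.msi /qn AUTOSTART=0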

File diff suppressed because it is too large

View File

@@ -64,17 +64,6 @@ service DaemonService {
   rpc TracePacket(TracePacketRequest) returns (TracePacketResponse) {}

-  // StartCapture begins streaming packet capture on the WireGuard interface.
-  // Requires --enable-capture set at service install/reconfigure time.
-  rpc StartCapture(StartCaptureRequest) returns (stream CapturePacket) {}
-
-  // StartBundleCapture begins capturing packets to a server-side temp file
-  // for inclusion in the next debug bundle. Auto-stops after the given timeout.
-  rpc StartBundleCapture(StartBundleCaptureRequest) returns (StartBundleCaptureResponse) {}
-
-  // StopBundleCapture stops the running bundle capture. Idempotent.
-  rpc StopBundleCapture(StopBundleCaptureRequest) returns (StopBundleCaptureResponse) {}
-
   rpc SubscribeEvents(SubscribeRequest) returns (stream SystemEvent) {}

   rpc GetEvents(GetEventsRequest) returns (GetEventsResponse) {}

@@ -115,6 +104,8 @@ service DaemonService {
   // StopCPUProfile stops CPU profiling in the daemon
   rpc StopCPUProfile(StopCPUProfileRequest) returns (StopCPUProfileResponse) {}

+  rpc NotifyOSLifecycle(OSLifecycleRequest) returns(OSLifecycleResponse) {}
+
   rpc GetInstallerResult(InstallerResultRequest) returns (InstallerResultResponse) {}

   // ExposeService exposes a local port via the NetBird reverse proxy

@@ -123,6 +114,20 @@ service DaemonService {
+message OSLifecycleRequest {
+  // avoid collision with loglevel enum
+  enum CycleType {
+    UNKNOWN = 0;
+    SLEEP = 1;
+    WAKEUP = 2;
+  }
+  CycleType type = 1;
+}
+
+message OSLifecycleResponse {}
+
 message LoginRequest {
   // setupKey netbird setup key.
   string setupKey = 1;

@@ -843,26 +848,3 @@ message ExposeServiceReady {
   string domain = 3;
   bool port_auto_assigned = 4;
 }
-
-message StartCaptureRequest {
-  bool text_output = 1;
-  uint32 snap_len = 2;
-  google.protobuf.Duration duration = 3;
-  string filter_expr = 4;
-  bool verbose = 5;
-  bool ascii = 6;
-}
-
-message CapturePacket {
-  bytes data = 1;
-}
-
-message StartBundleCaptureRequest {
-  // timeout auto-stops the capture after this duration.
-  // Clamped to a server-side maximum (10 minutes). Zero or unset defaults to the maximum.
-  google.protobuf.Duration timeout = 1;
-}
-
-message StartBundleCaptureResponse {}
-
-message StopBundleCaptureRequest {}
-
-message StopBundleCaptureResponse {}

File diff suppressed because it is too large

View File

@@ -1,365 +0,0 @@
package server
import (
"context"
"io"
"os"
"sync"
"time"
log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"github.com/netbirdio/netbird/client/internal"
"github.com/netbirdio/netbird/client/proto"
"github.com/netbirdio/netbird/util/capture"
)
const maxBundleCaptureDuration = 10 * time.Minute
// bundleCapture holds the state of an in-progress capture destined for the
// debug bundle. The lifecycle is:
//
// StartBundleCapture → capture running, writing to temp file
// StopBundleCapture → capture stopped, temp file available
// DebugBundle → temp file included in zip, then cleaned up
type bundleCapture struct {
mu sync.Mutex
sess *capture.Session
file *os.File
engine *internal.Engine
cancel context.CancelFunc
stopped bool
}
// stop halts the capture session and closes the pcap writer. Idempotent.
func (bc *bundleCapture) stop() {
bc.mu.Lock()
defer bc.mu.Unlock()
if bc.stopped {
return
}
bc.stopped = true
if bc.cancel != nil {
bc.cancel()
}
if bc.sess != nil {
bc.sess.Stop()
}
}
// path returns the temp file path, or "" if no file exists.
func (bc *bundleCapture) path() string {
if bc.file == nil {
return ""
}
return bc.file.Name()
}
// cleanup removes the temp file.
func (bc *bundleCapture) cleanup() {
if bc.file == nil {
return
}
name := bc.file.Name()
if err := bc.file.Close(); err != nil {
log.Debugf("close bundle capture file: %v", err)
}
if err := os.Remove(name); err != nil && !os.IsNotExist(err) {
log.Debugf("remove bundle capture file: %v", err)
}
bc.file = nil
}
// StartCapture streams a pcap or text packet capture over gRPC.
// Gated by the --enable-capture service flag.
func (s *Server) StartCapture(req *proto.StartCaptureRequest, stream proto.DaemonService_StartCaptureServer) error {
if !s.captureEnabled {
return status.Error(codes.PermissionDenied,
"packet capture is disabled; reinstall or reconfigure the service with --enable-capture")
}
if d := req.GetDuration(); d != nil && d.AsDuration() < 0 {
return status.Error(codes.InvalidArgument, "duration must not be negative")
}
matcher, err := parseCaptureFilter(req)
if err != nil {
return status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
pr, pw := io.Pipe()
opts := capture.Options{
Matcher: matcher,
SnapLen: req.GetSnapLen(),
Verbose: req.GetVerbose(),
ASCII: req.GetAscii(),
}
if req.GetTextOutput() {
opts.TextOutput = pw
} else {
opts.Output = pw
}
sess, err := capture.NewSession(opts)
if err != nil {
pw.Close()
return status.Errorf(codes.Internal, "create capture session: %v", err)
}
engine, err := s.claimCapture(sess)
if err != nil {
sess.Stop()
pw.Close()
return err
}
if err := engine.SetCapture(sess); err != nil {
s.releaseCapture(sess)
sess.Stop()
pw.Close()
return status.Errorf(codes.Internal, "set capture: %v", err)
}
// Send an empty initial message to signal that the capture was accepted.
// The client waits for this before printing the banner, so it must arrive
// before any packet data.
if err := stream.Send(&proto.CapturePacket{}); err != nil {
s.clearCaptureIfOwner(sess, engine)
sess.Stop()
pw.Close()
return status.Errorf(codes.Internal, "send initial message: %v", err)
}
ctx := stream.Context()
if d := req.GetDuration(); d != nil {
if dur := d.AsDuration(); dur > 0 {
var cancel context.CancelFunc
ctx, cancel = context.WithTimeout(ctx, dur)
defer cancel()
}
}
go func() {
<-ctx.Done()
s.clearCaptureIfOwner(sess, engine)
sess.Stop()
pw.Close()
}()
defer pr.Close()
log.Infof("packet capture started (text=%v, expr=%q)", req.GetTextOutput(), req.GetFilterExpr())
defer func() {
stats := sess.Stats()
log.Infof("packet capture stopped: %d packets, %d bytes, %d dropped",
stats.Packets, stats.Bytes, stats.Dropped)
}()
return streamToGRPC(pr, stream)
}
func streamToGRPC(r io.Reader, stream proto.DaemonService_StartCaptureServer) error {
buf := make([]byte, 32*1024)
for {
n, readErr := r.Read(buf)
if n > 0 {
if err := stream.Send(&proto.CapturePacket{Data: buf[:n]}); err != nil {
log.Debugf("capture stream send: %v", err)
return nil //nolint:nilerr // client disconnected
}
}
if readErr != nil {
return nil //nolint:nilerr // pipe closed, capture stopped normally
}
}
}
// StartBundleCapture begins capturing packets to a server-side temp file for
// inclusion in the next debug bundle. Not gated by --enable-capture since the
// output stays on the server (same trust level as CPU profiling).
//
// A timeout auto-stops the capture as a safety net if StopBundleCapture is
// never called (e.g. CLI crash).
func (s *Server) StartBundleCapture(_ context.Context, req *proto.StartBundleCaptureRequest) (*proto.StartBundleCaptureResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.stopBundleCaptureLocked()
s.cleanupBundleCapture()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
// Not fatal: kernel mode or not connected. Log and return success
// so the debug bundle still generates without capture data.
log.Warnf("packet capture unavailable, skipping: %v", err)
return &proto.StartBundleCaptureResponse{}, nil
}
timeout := req.GetTimeout().AsDuration()
if timeout <= 0 || timeout > maxBundleCaptureDuration {
timeout = maxBundleCaptureDuration
}
f, err := os.CreateTemp("", "netbird.capture.*.pcap")
if err != nil {
return nil, status.Errorf(codes.Internal, "create temp file: %v", err)
}
sess, err := capture.NewSession(capture.Options{Output: f})
if err != nil {
f.Close()
os.Remove(f.Name())
return nil, status.Errorf(codes.Internal, "create capture session: %v", err)
}
if err := engine.SetCapture(sess); err != nil {
sess.Stop()
f.Close()
os.Remove(f.Name())
log.Warnf("packet capture unavailable (no filtered device), skipping: %v", err)
return &proto.StartBundleCaptureResponse{}, nil
}
s.activeCapture = sess
ctx, cancel := context.WithTimeout(context.Background(), timeout)
bc := &bundleCapture{
sess: sess,
file: f,
engine: engine,
cancel: cancel,
}
s.bundleCapture = bc
go func() {
<-ctx.Done()
s.mutex.Lock()
if s.bundleCapture == bc {
s.stopBundleCaptureLocked()
} else {
bc.stop()
}
s.mutex.Unlock()
log.Infof("bundle capture auto-stopped after timeout")
}()
log.Infof("bundle capture started (timeout=%s, file=%s)", timeout, f.Name())
return &proto.StartBundleCaptureResponse{}, nil
}
// StopBundleCapture stops the running bundle capture. Idempotent.
func (s *Server) StopBundleCapture(_ context.Context, _ *proto.StopBundleCaptureRequest) (*proto.StopBundleCaptureResponse, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
s.stopBundleCaptureLocked()
return &proto.StopBundleCaptureResponse{}, nil
}
// stopBundleCaptureLocked stops the bundle capture if running. Must hold s.mutex.
func (s *Server) stopBundleCaptureLocked() {
if s.bundleCapture == nil {
return
}
bc := s.bundleCapture
if bc.engine != nil && s.activeCapture == bc.sess {
if err := bc.engine.SetCapture(nil); err != nil {
log.Debugf("clear bundle capture: %v", err)
}
s.activeCapture = nil
}
bc.stop()
stats := bc.sess.Stats()
log.Infof("bundle capture stopped: %d packets, %d bytes, %d dropped",
stats.Packets, stats.Bytes, stats.Dropped)
}
// bundleCapturePath returns the temp file path if a capture has been taken,
// stops any running capture, and returns "". Called from DebugBundle.
// Must hold s.mutex.
func (s *Server) bundleCapturePath() string {
if s.bundleCapture == nil {
return ""
}
s.bundleCapture.stop()
return s.bundleCapture.path()
}
// cleanupBundleCapture removes the temp file and clears state. Must hold s.mutex.
func (s *Server) cleanupBundleCapture() {
if s.bundleCapture == nil {
return
}
s.bundleCapture.cleanup()
s.bundleCapture = nil
}
// claimCapture reserves the engine's capture slot for sess. Returns
// FailedPrecondition if another capture is already active.
func (s *Server) claimCapture(sess *capture.Session) (*internal.Engine, error) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture != nil {
return nil, status.Error(codes.FailedPrecondition, "another capture is already running")
}
engine, err := s.getCaptureEngineLocked()
if err != nil {
return nil, err
}
s.activeCapture = sess
return engine, nil
}
// releaseCapture clears the active-capture owner if it still matches sess.
func (s *Server) releaseCapture(sess *capture.Session) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture == sess {
s.activeCapture = nil
}
}
// clearCaptureIfOwner clears engine's capture slot only if sess still owns it.
func (s *Server) clearCaptureIfOwner(sess *capture.Session, engine *internal.Engine) {
s.mutex.Lock()
defer s.mutex.Unlock()
if s.activeCapture != sess {
return
}
if err := engine.SetCapture(nil); err != nil {
log.Debugf("clear capture: %v", err)
}
s.activeCapture = nil
}
func (s *Server) getCaptureEngineLocked() (*internal.Engine, error) {
if s.connectClient == nil {
return nil, status.Error(codes.FailedPrecondition, "client not connected")
}
engine := s.connectClient.Engine()
if engine == nil {
return nil, status.Error(codes.FailedPrecondition, "engine not initialized")
}
return engine, nil
}
// parseCaptureFilter returns a Matcher from the request.
// Returns nil (match all) when no filter expression is set.
func parseCaptureFilter(req *proto.StartCaptureRequest) (capture.Matcher, error) {
expr := req.GetFilterExpr()
if expr == "" {
return nil, nil //nolint:nilnil // nil Matcher means "match all"
}
return capture.ParseFilter(expr)
}
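claimCapture, releaseCapture, and clearCaptureIfOwner above implement a mutex-guarded single-owner slot: only the session that currently owns the slot may clear it, so teardown can be triggered safely from RPC handlers, timers, and defers at the same time. Stripped of the capture specifics, the pattern looks like this (a generic sketch, not a NetBird API):

    package main

    import (
    	"errors"
    	"fmt"
    	"sync"
    )

    // slot hands out exclusive ownership of a shared resource. Release is a
    // no-op unless the caller still owns the slot, which makes competing
    // cleanup paths idempotent and race-free.
    type slot struct {
    	mu    sync.Mutex
    	owner any
    }

    func (s *slot) claim(owner any) error {
    	s.mu.Lock()
    	defer s.mu.Unlock()
    	if s.owner != nil {
    		return errors.New("another owner is already active")
    	}
    	s.owner = owner
    	return nil
    }

    func (s *slot) releaseIfOwner(owner any) {
    	s.mu.Lock()
    	defer s.mu.Unlock()
    	if s.owner == owner {
    		s.owner = nil
    	}
    }

    func main() {
    	var s slot
    	a, b := "session-a", "session-b"
    	fmt.Println(s.claim(a)) // <nil>
    	fmt.Println(s.claim(b)) // error: another owner is already active
    	s.releaseIfOwner(b)     // no-op: b never owned the slot
    	s.releaseIfOwner(a)     // frees the slot
    	fmt.Println(s.claim(b)) // <nil>
    }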

View File

@@ -43,9 +43,7 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
 		}()
 	}

-	capturePath := s.bundleCapturePath()
-	defer s.cleanupBundleCapture()
-
+	// Prepare refresh callback for health probes
 	var refreshStatus func()
 	if s.connectClient != nil {
 		engine := s.connectClient.Engine()

@@ -64,7 +62,6 @@ func (s *Server) DebugBundle(_ context.Context, req *proto.DebugBundleRequest) (
 			SyncResponse:  syncResponse,
 			LogPath:       s.logFile,
 			CPUProfile:    cpuProfileData,
-			CapturePath:   capturePath,
 			RefreshStatus: refreshStatus,
 			ClientMetrics: clientMetrics,
 		},

View File

@@ -33,7 +33,6 @@ import (
 	"github.com/netbirdio/netbird/client/internal/statemanager"
 	"github.com/netbirdio/netbird/client/internal/updater"
 	"github.com/netbirdio/netbird/client/proto"
-	"github.com/netbirdio/netbird/util/capture"
 	"github.com/netbirdio/netbird/version"
 )

@@ -90,11 +89,7 @@ type Server struct {
 	profileManager         *profilemanager.ServiceManager
 	profilesDisabled       bool
 	updateSettingsDisabled bool
-	captureEnabled         bool
-	bundleCapture          *bundleCapture
-	// activeCapture is the session currently installed on the engine; guarded by s.mutex.
-	activeCapture    *capture.Session
-	networksDisabled bool
+	networksDisabled       bool

 	sleepHandler *sleephandler.SleepHandler

@@ -111,7 +106,7 @@ type oauthAuthFlow struct {
 }

 // New server instance constructor.
-func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, captureEnabled bool, networksDisabled bool) *Server {
+func New(ctx context.Context, logFile string, configFile string, profilesDisabled bool, updateSettingsDisabled bool, networksDisabled bool) *Server {
 	s := &Server{
 		rootCtx: ctx,
 		logFile: logFile,

@@ -120,13 +115,11 @@ func New(ctx context.Context, logFile string, configFile string, profilesDisable
 		profileManager:         profilemanager.NewServiceManager(configFile),
 		profilesDisabled:       profilesDisabled,
 		updateSettingsDisabled: updateSettingsDisabled,
-		captureEnabled:         captureEnabled,
 		networksDisabled:       networksDisabled,
 		jwtCache:               newJWTCache(),
 	}
 	agent := &serverAgent{s}
 	s.sleepHandler = sleephandler.New(agent)
-	s.startSleepDetector()
 	return s
 }

View File

@@ -104,7 +104,7 @@ func TestConnectWithRetryRuns(t *testing.T) {
 		t.Fatalf("failed to set active profile state: %v", err)
 	}

-	s := New(ctx, "debug", "", false, false, false, false)
+	s := New(ctx, "debug", "", false, false, false)

 	s.config = config

@@ -165,7 +165,7 @@ func TestServer_Up(t *testing.T) {
 		t.Fatalf("failed to set active profile state: %v", err)
 	}

-	s := New(ctx, "console", "", false, false, false, false)
+	s := New(ctx, "console", "", false, false, false)

 	err = s.Start()
 	require.NoError(t, err)

@@ -235,7 +235,7 @@ func TestServer_SubcribeEvents(t *testing.T) {
 		t.Fatalf("failed to set active profile state: %v", err)
 	}

-	s := New(ctx, "console", "", false, false, false, false)
+	s := New(ctx, "console", "", false, false, false)

 	err = s.Start()
 	require.NoError(t, err)

@@ -335,7 +335,7 @@ func startManagement(t *testing.T, signalAddr string, counter *int) (*grpc.Serve
 	if err != nil {
 		return nil, "", err
 	}
-	mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil, nil)
+	mgmtServer, err := nbgrpc.NewServer(config, accountManager, settingsMockManager, jobManager, secretsManager, nil, nil, &server.MockIntegratedValidator{}, networkMapController, nil)
 	if err != nil {
 		return nil, "", err
 	}

View File

@@ -53,7 +53,7 @@ func TestSetConfig_AllFieldsSaved(t *testing.T) {
 	require.NoError(t, err)

 	ctx := context.Background()
-	s := New(ctx, "console", "", false, false, false, false)
+	s := New(ctx, "console", "", false, false, false)

 	rosenpassEnabled := true
 	rosenpassPermissive := true

View File

@@ -2,18 +2,13 @@ package server

 import (
 	"context"
-	"os"
-	"strconv"

 	log "github.com/sirupsen/logrus"

 	"github.com/netbirdio/netbird/client/internal"
-	"github.com/netbirdio/netbird/client/internal/sleep"
 	"github.com/netbirdio/netbird/client/proto"
 )

-const envDisableSleepDetector = "NB_DISABLE_SLEEP_DETECTOR"
-
 // serverAgent adapts Server to the handler.Agent and handler.StatusChecker interfaces
 type serverAgent struct {
 	s *Server

@@ -33,61 +28,19 @@ func (a *serverAgent) Status() (internal.StatusType, error) {
 	return internal.CtxGetState(a.s.rootCtx).Status()
 }

-// startSleepDetector starts the OS sleep/wake detector and forwards events to
-// the sleep handler. On platforms without a supported detector the attempt
-// logs a warning and returns. Setting NB_DISABLE_SLEEP_DETECTOR=true skips
-// registration entirely.
-func (s *Server) startSleepDetector() {
-	if sleepDetectorDisabled() {
-		log.Info("sleep detection disabled via " + envDisableSleepDetector)
-		return
-	}
-	svc, err := sleep.New()
-	if err != nil {
-		log.Warnf("failed to initialize sleep detection: %v", err)
-		return
-	}
-	err = svc.Register(func(event sleep.EventType) {
-		switch event {
-		case sleep.EventTypeSleep:
-			log.Info("handling sleep event")
-			if err := s.sleepHandler.HandleSleep(s.rootCtx); err != nil {
-				log.Errorf("failed to handle sleep event: %v", err)
-			}
-		case sleep.EventTypeWakeUp:
-			log.Info("handling wakeup event")
-			if err := s.sleepHandler.HandleWakeUp(s.rootCtx); err != nil {
-				log.Errorf("failed to handle wakeup event: %v", err)
-			}
-		}
-	})
-	if err != nil {
-		log.Errorf("failed to register sleep detector: %v", err)
-		return
-	}
-	log.Info("sleep detection service initialized")
-	go func() {
-		<-s.rootCtx.Done()
-		log.Info("stopping sleep event listener")
-		if err := svc.Deregister(); err != nil {
-			log.Errorf("failed to deregister sleep detector: %v", err)
-		}
-	}()
-}
-
-func sleepDetectorDisabled() bool {
-	val := os.Getenv(envDisableSleepDetector)
-	if val == "" {
-		return false
-	}
-	disabled, err := strconv.ParseBool(val)
-	if err != nil {
-		log.Warnf("failed to parse %s=%q: %v", envDisableSleepDetector, val, err)
-		return false
-	}
-	return disabled
-}
+// NotifyOSLifecycle handles operating system lifecycle events by executing appropriate logic based on the request type.
+func (s *Server) NotifyOSLifecycle(callerCtx context.Context, req *proto.OSLifecycleRequest) (*proto.OSLifecycleResponse, error) {
+	switch req.GetType() {
+	case proto.OSLifecycleRequest_WAKEUP:
+		if err := s.sleepHandler.HandleWakeUp(callerCtx); err != nil {
+			return &proto.OSLifecycleResponse{}, err
+		}
+	case proto.OSLifecycleRequest_SLEEP:
+		if err := s.sleepHandler.HandleSleep(callerCtx); err != nil {
+			return &proto.OSLifecycleResponse{}, err
+		}
+	default:
+		log.Errorf("unknown OSLifecycleRequest type: %v", req.GetType())
+	}
+	return &proto.OSLifecycleResponse{}, nil
+}

View File

@@ -224,20 +224,15 @@ func (m *Manager) buildHostPatterns(peer PeerSSHInfo) []string {
 func (m *Manager) writeSSHConfig(sshConfig string) error {
 	sshConfigPath := filepath.Join(m.sshConfigDir, m.sshConfigFile)
-	sshConfigPathTmp := sshConfigPath + ".tmp"

 	if err := os.MkdirAll(m.sshConfigDir, 0755); err != nil {
 		return fmt.Errorf("create SSH config directory %s: %w", m.sshConfigDir, err)
 	}

-	if err := writeFileWithTimeout(sshConfigPathTmp, []byte(sshConfig), 0644); err != nil {
+	if err := writeFileWithTimeout(sshConfigPath, []byte(sshConfig), 0644); err != nil {
 		return fmt.Errorf("write SSH config file %s: %w", sshConfigPath, err)
 	}

-	if err := os.Rename(sshConfigPathTmp, sshConfigPath); err != nil {
-		return fmt.Errorf("rename ssh config %s -> %s: %w", sshConfigPathTmp, sshConfigPath, err)
-	}
-
 	log.Infof("Created NetBird SSH client config: %s", sshConfigPath)
 	return nil
 }
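What the revert drops here is the standard atomic-replace idiom: write to a temporary file, then rename it over the destination so readers never observe a half-written config. In generic form (a sketch of the technique, not the repository's writeFileWithTimeout helper):

    package main

    import (
    	"fmt"
    	"os"
    	"path/filepath"
    )

    // writeFileAtomic writes data to a temp file next to the destination and
    // renames it into place. On POSIX systems rename is atomic, so concurrent
    // readers see either the old or the new content, never a mix.
    func writeFileAtomic(path string, data []byte, perm os.FileMode) error {
    	tmp := path + ".tmp"
    	if err := os.WriteFile(tmp, data, perm); err != nil {
    		return fmt.Errorf("write %s: %w", tmp, err)
    	}
    	if err := os.Rename(tmp, path); err != nil {
    		os.Remove(tmp) // best-effort cleanup
    		return fmt.Errorf("rename %s -> %s: %w", tmp, path, err)
    	}
    	return nil
    }

    func main() {
    	path := filepath.Join(os.TempDir(), "ssh_config_demo")
    	if err := writeFileAtomic(path, []byte("Host demo\n"), 0o644); err != nil {
    		fmt.Println(err)
    		return
    	}
    	fmt.Println("wrote", path)
    }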

View File

@@ -38,10 +38,10 @@ import (
 	"github.com/netbirdio/netbird/client/iface"
 	"github.com/netbirdio/netbird/client/internal"
 	"github.com/netbirdio/netbird/client/internal/profilemanager"
+	"github.com/netbirdio/netbird/client/internal/sleep"
 	"github.com/netbirdio/netbird/client/proto"
 	"github.com/netbirdio/netbird/client/ui/desktop"
 	"github.com/netbirdio/netbird/client/ui/event"
-	"github.com/netbirdio/netbird/client/ui/notifier"
 	"github.com/netbirdio/netbird/client/ui/process"
 	"github.com/netbirdio/netbird/util"

@@ -260,7 +260,6 @@ type serviceClient struct {
 	// application with main windows.
 	app fyne.App

-	notifier  notifier.Notifier
 	wSettings fyne.Window

 	showAdvancedSettings bool
 	sendNotification     bool

@@ -365,7 +364,6 @@ func newServiceClient(args *newServiceClientArgs) *serviceClient {
 		cancel: cancel,
 		addr:   args.addr,
 		app:    args.app,
-		notifier: notifier.New(args.app),

 		logFile:          args.logFile,
 		sendNotification: false,

@@ -894,7 +892,7 @@ func (s *serviceClient) updateStatus() error {
 	if err != nil {
 		log.Errorf("get service status: %v", err)
 		if s.connected {
-			s.notifier.Send("Error", "Connection to service lost")
+			s.app.SendNotification(fyne.NewNotification("Error", "Connection to service lost"))
 		}
 		s.setDisconnectedStatus()
 		return err

@@ -1111,7 +1109,7 @@ func (s *serviceClient) onTrayReady() {
 		}
 	}()

-	s.eventManager = event.NewManager(s.notifier, s.addr)
+	s.eventManager = event.NewManager(s.app, s.addr)
 	s.eventManager.SetNotificationsEnabled(s.mNotifications.Checked())
 	s.eventManager.AddHandler(func(event *proto.SystemEvent) {
 		if event.Category == proto.SystemEvent_SYSTEM {

@@ -1148,6 +1146,9 @@ func (s *serviceClient) onTrayReady() {
 	go s.eventManager.Start(s.ctx)
 	go s.eventHandler.listen(s.ctx)
+
+	// Start sleep detection listener
+	go s.startSleepListener()
 }

 func (s *serviceClient) attachOutput(cmd *exec.Cmd) *os.File {

@@ -1208,6 +1209,62 @@ func (s *serviceClient) getSrvClient(timeout time.Duration) (proto.DaemonService
 	return s.conn, nil
 }

+// startSleepListener initializes the sleep detection service and listens for sleep events
+func (s *serviceClient) startSleepListener() {
+	sleepService, err := sleep.New()
+	if err != nil {
+		log.Warnf("%v", err)
+		return
+	}
+
+	if err := sleepService.Register(s.handleSleepEvents); err != nil {
+		log.Errorf("failed to start sleep detection: %v", err)
+		return
+	}
+	log.Info("sleep detection service initialized")
+
+	// Cleanup on context cancellation
+	go func() {
+		<-s.ctx.Done()
+		log.Info("stopping sleep event listener")
+		if err := sleepService.Deregister(); err != nil {
+			log.Errorf("failed to deregister sleep detection: %v", err)
+		}
+	}()
+}
+
+// handleSleepEvents sends a sleep notification to the daemon via gRPC
+func (s *serviceClient) handleSleepEvents(event sleep.EventType) {
+	conn, err := s.getSrvClient(0)
+	if err != nil {
+		log.Errorf("failed to get daemon client for sleep notification: %v", err)
+		return
+	}
+
+	req := &proto.OSLifecycleRequest{}
+	switch event {
+	case sleep.EventTypeWakeUp:
+		log.Infof("handle wakeup event: %v", event)
+		req.Type = proto.OSLifecycleRequest_WAKEUP
+	case sleep.EventTypeSleep:
+		log.Infof("handle sleep event: %v", event)
+		req.Type = proto.OSLifecycleRequest_SLEEP
+	default:
+		log.Infof("unknown event: %v", event)
+		return
+	}
+
+	_, err = conn.NotifyOSLifecycle(s.ctx, req)
+	if err != nil {
+		log.Errorf("failed to notify daemon about os lifecycle notification: %v", err)
+		return
+	}
+	log.Info("successfully notified daemon about os lifecycle")
+}
+
 // setSettingsEnabled enables or disables the settings menu based on the provided state
 func (s *serviceClient) setSettingsEnabled(enabled bool) {
 	if s.mSettings != nil {

@@ -1491,7 +1548,7 @@ func (s *serviceClient) onUpdateAvailable(newVersion string, enforced bool) {
 	if enforced && s.lastNotifiedVersion != newVersion {
 		s.lastNotifiedVersion = newVersion
-		s.notifier.Send("Update available", "A new version "+newVersion+" is ready to install")
+		s.app.SendNotification(fyne.NewNotification("Update available", "A new version "+newVersion+" is ready to install"))
 	}
 }

View File

@@ -16,7 +16,6 @@ import (
 	"fyne.io/fyne/v2/widget"
 	log "github.com/sirupsen/logrus"
 	"github.com/skratchdot/open-golang/open"
-	"google.golang.org/protobuf/types/known/durationpb"

 	"github.com/netbirdio/netbird/client/internal"
 	"github.com/netbirdio/netbird/client/proto"

@@ -39,7 +38,6 @@ type debugCollectionParams struct {
 	upload            bool
 	uploadURL         string
 	enablePersistence bool
-	capture           bool
 }

 // UI components for progress tracking

@@ -53,58 +51,25 @@ type progressUI struct {
 func (s *serviceClient) showDebugUI() {
 	w := s.app.NewWindow("NetBird Debug")
 	w.SetOnClosed(s.cancel)
 	w.Resize(fyne.NewSize(600, 500))
 	w.SetFixedSize(true)

 	anonymizeCheck := widget.NewCheck("Anonymize sensitive information (public IPs, domains, ...)", nil)

 	systemInfoCheck := widget.NewCheck("Include system information (routes, interfaces, ...)", nil)
 	systemInfoCheck.SetChecked(true)

-	captureCheck := widget.NewCheck("Include packet capture", nil)
-
 	uploadCheck := widget.NewCheck("Upload bundle automatically after creation", nil)
 	uploadCheck.SetChecked(true)

-	uploadURLContainer, uploadURL := s.buildUploadSection(uploadCheck)
-	debugModeContainer, runForDurationCheck, durationInput, noteLabel := s.buildDurationSection()
-
-	statusLabel := widget.NewLabel("")
-	statusLabel.Hide()
-	progressBar := widget.NewProgressBar()
-	progressBar.Hide()
-
-	createButton := widget.NewButton("Create Debug Bundle", nil)
-
-	uiControls := []fyne.Disableable{
-		anonymizeCheck, systemInfoCheck, captureCheck,
-		uploadCheck, uploadURL, runForDurationCheck, durationInput, createButton,
-	}
-
-	createButton.OnTapped = s.getCreateHandler(
-		statusLabel, progressBar, uploadCheck, uploadURL,
-		anonymizeCheck, systemInfoCheck, captureCheck,
-		runForDurationCheck, durationInput, uiControls, w,
-	)
-
-	content := container.NewVBox(
-		widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
-		widget.NewLabel(""),
-		anonymizeCheck, systemInfoCheck, captureCheck,
-		uploadCheck, uploadURLContainer,
-		widget.NewLabel(""),
-		debugModeContainer, noteLabel,
-		widget.NewLabel(""),
-		statusLabel, progressBar, createButton,
-	)
-
-	w.SetContent(container.NewPadded(content))
-	w.Show()
-}
-
-func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Container, *widget.Entry) {
+	uploadURLLabel := widget.NewLabel("Debug upload URL:")
 	uploadURL := widget.NewEntry()
 	uploadURL.SetText(uptypes.DefaultBundleURL)
 	uploadURL.SetPlaceHolder("Enter upload URL")
-	uploadURLContainer := container.NewVBox(widget.NewLabel("Debug upload URL:"), uploadURL)
+	uploadURLContainer := container.NewVBox(
+		uploadURLLabel,
+		uploadURL,
+	)

 	uploadCheck.OnChanged = func(checked bool) {
 		if checked {

@@ -113,14 +78,13 @@ func (s *serviceClient) buildUploadSection(uploadCheck *widget.Check) (*fyne.Con
 			uploadURLContainer.Hide()
 		}
 	}
-	return uploadURLContainer, uploadURL
-}

-func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check, *widget.Entry, *widget.Label) {
+	debugModeContainer := container.NewHBox()
 	runForDurationCheck := widget.NewCheck("Run with trace logs before creating bundle", nil)
 	runForDurationCheck.SetChecked(true)

 	forLabel := widget.NewLabel("for")
 	durationInput := widget.NewEntry()
 	durationInput.SetText("1")
 	minutesLabel := widget.NewLabel("minute")

@@ -144,8 +108,63 @@ func (s *serviceClient) buildDurationSection() (*fyne.Container, *widget.Check,
 		}
 	}

-	modeContainer := container.NewHBox(runForDurationCheck, forLabel, durationInput, minutesLabel)
-	return modeContainer, runForDurationCheck, durationInput, noteLabel
+	debugModeContainer.Add(runForDurationCheck)
+	debugModeContainer.Add(forLabel)
+	debugModeContainer.Add(durationInput)
+	debugModeContainer.Add(minutesLabel)
+
+	statusLabel := widget.NewLabel("")
+	statusLabel.Hide()
+	progressBar := widget.NewProgressBar()
+	progressBar.Hide()
+
+	createButton := widget.NewButton("Create Debug Bundle", nil)
+
+	// UI controls that should be disabled during debug collection
+	uiControls := []fyne.Disableable{
+		anonymizeCheck,
+		systemInfoCheck,
+		uploadCheck,
+		uploadURL,
+		runForDurationCheck,
+		durationInput,
+		createButton,
+	}
+
+	createButton.OnTapped = s.getCreateHandler(
+		statusLabel,
+		progressBar,
+		uploadCheck,
+		uploadURL,
+		anonymizeCheck,
+		systemInfoCheck,
+		runForDurationCheck,
+		durationInput,
+		uiControls,
+		w,
+	)
+
+	content := container.NewVBox(
+		widget.NewLabel("Create a debug bundle to help troubleshoot issues with NetBird"),
+		widget.NewLabel(""),
+		anonymizeCheck,
+		systemInfoCheck,
+		uploadCheck,
+		uploadURLContainer,
+		widget.NewLabel(""),
+		debugModeContainer,
+		noteLabel,
+		widget.NewLabel(""),
+		statusLabel,
+		progressBar,
+		createButton,
+	)
+
+	paddedContent := container.NewPadded(content)
+
+	w.SetContent(paddedContent)
+	w.Show()
 }

 func validateMinute(s string, minutesLabel *widget.Label) error {

@@ -181,7 +200,6 @@ func (s *serviceClient) getCreateHandler(
 	uploadURL *widget.Entry,
 	anonymizeCheck *widget.Check,
 	systemInfoCheck *widget.Check,
-	captureCheck *widget.Check,
 	runForDurationCheck *widget.Check,
 	duration *widget.Entry,
 	uiControls []fyne.Disableable,

@@ -204,7 +222,6 @@ func (s *serviceClient) getCreateHandler(
 		params := &debugCollectionParams{
 			anonymize:         anonymizeCheck.Checked,
 			systemInfo:        systemInfoCheck.Checked,
-			capture:           captureCheck.Checked,
 			upload:            uploadCheck.Checked,
 			uploadURL:         url,
 			enablePersistence: true,

@@ -236,7 +253,10 @@ func (s *serviceClient) getCreateHandler(
 		statusLabel.SetText("Creating debug bundle...")

 		go s.handleDebugCreation(
-			params,
+			anonymizeCheck.Checked,
+			systemInfoCheck.Checked,
+			uploadCheck.Checked,
+			url,
 			statusLabel,
 			uiControls,
 			w,

@@ -351,7 +371,7 @@ func startProgressTracker(ctx context.Context, wg *sync.WaitGroup, duration time
 func (s *serviceClient) configureServiceForDebug(
 	conn proto.DaemonServiceClient,
 	state *debugInitialState,
-	params *debugCollectionParams,
+	enablePersistence bool,
 ) {
 	if state.wasDown {
 		if _, err := conn.Up(s.ctx, &proto.UpRequest{}); err != nil {

@@ -377,7 +397,7 @@ func (s *serviceClient) configureServiceForDebug(
 		time.Sleep(time.Second)
 	}

-	if params.enablePersistence {
+	if enablePersistence {
 		if _, err := conn.SetSyncResponsePersistence(s.ctx, &proto.SetSyncResponsePersistenceRequest{
 			Enabled: true,
 		}); err != nil {

@@ -397,26 +417,6 @@ func (s *serviceClient) configureServiceForDebug(
 	if _, err := conn.StartCPUProfile(s.ctx, &proto.StartCPUProfileRequest{}); err != nil {
 		log.Warnf("failed to start CPU profiling: %v", err)
 	}
-
-	s.startBundleCaptureIfEnabled(conn, params)
-}
-
-func (s *serviceClient) startBundleCaptureIfEnabled(conn proto.DaemonServiceClient, params *debugCollectionParams) {
-	if !params.capture {
-		return
-	}
-	const maxCapture = 10 * time.Minute
-	timeout := params.duration + 30*time.Second
-	if timeout > maxCapture {
-		timeout = maxCapture
-		log.Warnf("packet capture clamped to %s (server maximum)", maxCapture)
-	}
-	if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
-		Timeout: durationpb.New(timeout),
-	}); err != nil {
-		log.Warnf("failed to start bundle capture: %v", err)
-	}
 }

 func (s *serviceClient) collectDebugData(

@@ -430,7 +430,7 @@ func (s *serviceClient) collectDebugData(
 	var wg sync.WaitGroup
 	startProgressTracker(ctx, &wg, params.duration, progress)

-	s.configureServiceForDebug(conn, state, params)
+	s.configureServiceForDebug(conn, state, params.enablePersistence)

 	wg.Wait()
 	progress.progressBar.Hide()

@@ -440,14 +440,6 @@ func (s *serviceClient) collectDebugData(
 		log.Warnf("failed to stop CPU profiling: %v", err)
 	}

-	if params.capture {
-		stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-		defer cancel()
-		if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
-			log.Warnf("failed to stop bundle capture: %v", err)
-		}
-	}
-
 	return nil
 }

@@ -528,37 +520,18 @@ func handleError(progress *progressUI, errMsg string) {
 }

 func (s *serviceClient) handleDebugCreation(
-	params *debugCollectionParams,
+	anonymize bool,
+	systemInfo bool,
+	upload bool,
+	uploadURL string,
 	statusLabel *widget.Label,
 	uiControls []fyne.Disableable,
 	w fyne.Window,
 ) {
-	conn, err := s.getSrvClient(failFastTimeout)
-	if err != nil {
-		log.Errorf("Failed to get client for debug: %v", err)
-		statusLabel.SetText(fmt.Sprintf("Error: %v", err))
-		enableUIControls(uiControls)
-		return
-	}
-
-	if params.capture {
-		if _, err := conn.StartBundleCapture(s.ctx, &proto.StartBundleCaptureRequest{
-			Timeout: durationpb.New(30 * time.Second),
-		}); err != nil {
-			log.Warnf("failed to start bundle capture: %v", err)
-		} else {
-			defer func() {
-				stopCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
-				defer cancel()
-				if _, err := conn.StopBundleCapture(stopCtx, &proto.StopBundleCaptureRequest{}); err != nil {
-					log.Warnf("failed to stop bundle capture: %v", err)
-				}
-			}()
-			time.Sleep(2 * time.Second)
-		}
-	}
-
-	resp, err := s.createDebugBundle(params.anonymize, params.systemInfo, params.uploadURL)
+	log.Infof("Creating debug bundle (Anonymized: %v, System Info: %v, Upload Attempt: %v)...",
+		anonymize, systemInfo, upload)
+
+	resp, err := s.createDebugBundle(anonymize, systemInfo, uploadURL)
 	if err != nil {
 		log.Errorf("Failed to create debug bundle: %v", err)
 		statusLabel.SetText(fmt.Sprintf("Error creating bundle: %v", err))

@@ -570,7 +543,7 @@ func (s *serviceClient) handleDebugCreation(
 	uploadFailureReason := resp.GetUploadFailureReason()
 	uploadedKey := resp.GetUploadedKey()

-	if params.upload {
+	if upload {
 		if uploadFailureReason != "" {
 			showUploadFailedDialog(w, localPath, uploadFailureReason)
 		} else {
View File
@@ -8,6 +8,7 @@ import (
"sync" "sync"
"time" "time"
"fyne.io/fyne/v2"
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
@@ -17,17 +18,11 @@ import (
"github.com/netbirdio/netbird/client/ui/desktop" "github.com/netbirdio/netbird/client/ui/desktop"
) )
// Notifier sends desktop notifications. Defined here so the event package
// does not depend on fyne or the platform-specific notifier implementation.
type Notifier interface {
Send(title, body string)
}
type Handler func(*proto.SystemEvent) type Handler func(*proto.SystemEvent)
type Manager struct { type Manager struct {
notifier Notifier app fyne.App
addr string addr string
mu sync.Mutex mu sync.Mutex
ctx context.Context ctx context.Context
@@ -36,10 +31,10 @@ type Manager struct {
handlers []Handler handlers []Handler
} }
func NewManager(notifier Notifier, addr string) *Manager { func NewManager(app fyne.App, addr string) *Manager {
return &Manager{ return &Manager{
notifier: notifier, app: app,
addr: addr, addr: addr,
} }
} }
@@ -119,7 +114,7 @@ func (e *Manager) handleEvent(event *proto.SystemEvent) {
if id != "" { if id != "" {
body += fmt.Sprintf(" ID: %s", id) body += fmt.Sprintf(" ID: %s", id)
} }
e.notifier.Send(title, body) e.app.SendNotification(fyne.NewNotification(title, body))
} }
for _, handler := range handlers { for _, handler := range handlers {
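The Notifier interface on the removed side existed so this package could send notifications without importing fyne; anything implementing Send satisfies it. A minimal sketch of a test double under that interface, with illustrative names:

// Sketch only: a recording double for the removed Notifier interface.
type recordingNotifier struct {
	sent [][2]string // {title, body} pairs, in call order
}

func (r *recordingNotifier) Send(title, body string) {
	r.sent = append(r.sent, [2]string{title, body})
}

After the revert, callers hold a fyne.App directly, so substituting a double like this requires a real fyne app instance instead.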
View File
@@ -9,6 +9,7 @@ import (
"os" "os"
"os/exec" "os/exec"
"fyne.io/fyne/v2"
"fyne.io/systray" "fyne.io/systray"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc/codes" "google.golang.org/grpc/codes"
@@ -86,7 +87,7 @@ func (h *eventHandler) handleConnectClick() {
if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) { if errors.Is(err, context.Canceled) || (ok && st.Code() == codes.Canceled) {
log.Debugf("connect operation cancelled by user") log.Debugf("connect operation cancelled by user")
} else { } else {
h.client.notifier.Send("Error", "Failed to connect") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to connect"))
log.Errorf("connect failed: %v", err) log.Errorf("connect failed: %v", err)
} }
} }
@@ -111,7 +112,7 @@ func (h *eventHandler) handleDisconnectClick() {
if err := h.client.menuDownClick(); err != nil { if err := h.client.menuDownClick(); err != nil {
st, ok := status.FromError(err) st, ok := status.FromError(err)
if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) { if !errors.Is(err, context.Canceled) && !(ok && st.Code() == codes.Canceled) {
h.client.notifier.Send("Error", "Failed to disconnect") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to disconnect"))
log.Errorf("disconnect failed: %v", err) log.Errorf("disconnect failed: %v", err)
} else { } else {
log.Debugf("disconnect cancelled or already disconnecting") log.Debugf("disconnect cancelled or already disconnecting")
@@ -129,7 +130,7 @@ func (h *eventHandler) handleAllowSSHClick() {
if err := h.updateConfigWithErr(); err != nil { if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error h.toggleCheckbox(h.client.mAllowSSH) // revert checkbox state on error
log.Errorf("failed to update config: %v", err) log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update SSH settings") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update SSH settings"))
} }
} }
@@ -139,7 +140,7 @@ func (h *eventHandler) handleAutoConnectClick() {
if err := h.updateConfigWithErr(); err != nil { if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error h.toggleCheckbox(h.client.mAutoConnect) // revert checkbox state on error
log.Errorf("failed to update config: %v", err) log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update auto-connect settings") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update auto-connect settings"))
} }
} }
@@ -148,7 +149,7 @@ func (h *eventHandler) handleRosenpassClick() {
if err := h.updateConfigWithErr(); err != nil { if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error h.toggleCheckbox(h.client.mEnableRosenpass) // revert checkbox state on error
log.Errorf("failed to update config: %v", err) log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update Rosenpass settings") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update Rosenpass settings"))
} }
} }
@@ -157,7 +158,7 @@ func (h *eventHandler) handleLazyConnectionClick() {
if err := h.updateConfigWithErr(); err != nil { if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error h.toggleCheckbox(h.client.mLazyConnEnabled) // revert checkbox state on error
log.Errorf("failed to update config: %v", err) log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update lazy connection settings") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update lazy connection settings"))
} }
} }
@@ -166,7 +167,7 @@ func (h *eventHandler) handleBlockInboundClick() {
if err := h.updateConfigWithErr(); err != nil { if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error h.toggleCheckbox(h.client.mBlockInbound) // revert checkbox state on error
log.Errorf("failed to update config: %v", err) log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update block inbound settings") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update block inbound settings"))
} }
} }
@@ -175,7 +176,7 @@ func (h *eventHandler) handleNotificationsClick() {
if err := h.updateConfigWithErr(); err != nil { if err := h.updateConfigWithErr(); err != nil {
h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error h.toggleCheckbox(h.client.mNotifications) // revert checkbox state on error
log.Errorf("failed to update config: %v", err) log.Errorf("failed to update config: %v", err)
h.client.notifier.Send("Error", "Failed to update notifications settings") h.client.app.SendNotification(fyne.NewNotification("Error", "Failed to update notifications settings"))
} else if h.client.eventManager != nil { } else if h.client.eventManager != nil {
h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked()) h.client.eventManager.SetNotificationsEnabled(h.client.mNotifications.Checked())
} }
View File
@@ -1,27 +0,0 @@
// Package notifier sends desktop notifications. On Windows it uses the WinRT
// COM API directly via go-toast/v2 to avoid the PowerShell window flash that
// fyne's default implementation produces. On other platforms it delegates to
// fyne.
package notifier
import "fyne.io/fyne/v2"
// Notifier sends desktop notifications.
type Notifier interface {
Send(title, body string)
}
// New returns a platform-specific Notifier. The fyne app is used as the
// fallback notifier on platforms where no native implementation is wired up,
// and on Windows when the COM path fails to initialize.
func New(app fyne.App) Notifier {
return newNotifier(app)
}
type fyneNotifier struct {
app fyne.App
}
func (f *fyneNotifier) Send(title, body string) {
f.app.SendNotification(fyne.NewNotification(title, body))
}
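A minimal usage sketch for the deleted package, assuming the fyne app comes from fyne.io/fyne/v2/app:

// Sketch only: how a caller would have wired the deleted package.
a := app.New()                 // fyne.App, the cross-platform fallback
n := notifier.New(a)           // COM-backed on Windows, fyne elsewhere
n.Send("NetBird", "Connected") // identical call on every platform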
View File
@@ -1,9 +0,0 @@
//go:build !windows
package notifier
import "fyne.io/fyne/v2"
func newNotifier(app fyne.App) Notifier {
return &fyneNotifier{app: app}
}
View File
@@ -1,88 +0,0 @@
package notifier
import (
"os"
"path/filepath"
"sync"
"fyne.io/fyne/v2"
toast "git.sr.ht/~jackmordaunt/go-toast/v2"
"git.sr.ht/~jackmordaunt/go-toast/v2/wintoast"
log "github.com/sirupsen/logrus"
)
const (
// appID is the AppUserModelID shown in the Windows Action Center. It
// must match the System.AppUserModel.ID property set on the Start Menu
// shortcut by the MSI (see client/netbird.wxs); otherwise Windows
// groups toasts under a separate, unbranded entry.
appID = "NetBird"
// appGUID identifies the COM activation callback class. Generated once
// for NetBird; do not change without coordinating an installer bump,
// since old registry entries pointing at the previous GUID would be orphaned.
appGUID = "{0E1B4DE7-E148-432B-9814-544F941826EC}"
)
type comNotifier struct {
fallback *fyneNotifier
ready bool
iconPath string
}
var (
initOnce sync.Once
initErr error
)
func newNotifier(app fyne.App) Notifier {
n := &comNotifier{
fallback: &fyneNotifier{app: app},
iconPath: resolveIcon(),
}
initOnce.Do(func() {
initErr = wintoast.SetAppData(wintoast.AppData{
AppID: appID,
GUID: appGUID,
IconPath: n.iconPath,
})
})
if initErr != nil {
log.Warnf("toast: register app data failed, falling back to fyne notifications: %v", initErr)
return n.fallback
}
n.ready = true
return n
}
func (n *comNotifier) Send(title, body string) {
if !n.ready {
n.fallback.Send(title, body)
return
}
notification := toast.Notification{
AppID: appID,
Title: title,
Body: body,
Icon: n.iconPath,
}
if err := notification.Push(); err != nil {
log.Warnf("toast: push failed, using fyne fallback: %v", err)
n.fallback.Send(title, body)
}
}
// resolveIcon returns an absolute path to the toast icon, or an empty string
// when no icon can be located. Windows requires a PNG/JPG for the
// AppUserModelId IconUri registry value; .ico is silently ignored.
func resolveIcon() string {
exe, err := os.Executable()
if err != nil {
return ""
}
candidate := filepath.Join(filepath.Dir(exe), "netbird.png")
if _, err := os.Stat(candidate); err == nil {
return candidate
}
return ""
}
View File
@@ -548,7 +548,7 @@ func (p *profileMenu) refresh() {
if err != nil { if err != nil {
log.Errorf("failed to switch profile: %v", err) log.Errorf("failed to switch profile: %v", err)
// show notification dialog // show notification dialog
p.serviceClient.notifier.Send("Error", "Failed to switch profile") p.app.SendNotification(fyne.NewNotification("Error", "Failed to switch profile"))
return return
} }
@@ -628,9 +628,9 @@ func (p *profileMenu) refresh() {
} }
if err := p.eventHandler.logout(p.ctx); err != nil { if err := p.eventHandler.logout(p.ctx); err != nil {
log.Errorf("logout failed: %v", err) log.Errorf("logout failed: %v", err)
p.serviceClient.notifier.Send("Error", "Failed to deregister") p.app.SendNotification(fyne.NewNotification("Error", "Failed to deregister"))
} else { } else {
p.serviceClient.notifier.Send("Success", "Deregistered successfully") p.app.SendNotification(fyne.NewNotification("Success", "Deregistered successfully"))
} }
} }
} }
View File
@@ -5,7 +5,6 @@ package main
import ( import (
"context" "context"
"fmt" "fmt"
"sync"
"syscall/js" "syscall/js"
"time" "time"
@@ -15,7 +14,6 @@ import (
netbird "github.com/netbirdio/netbird/client/embed" netbird "github.com/netbirdio/netbird/client/embed"
sshdetection "github.com/netbirdio/netbird/client/ssh/detection" sshdetection "github.com/netbirdio/netbird/client/ssh/detection"
nbstatus "github.com/netbirdio/netbird/client/status" nbstatus "github.com/netbirdio/netbird/client/status"
wasmcapture "github.com/netbirdio/netbird/client/wasm/internal/capture"
"github.com/netbirdio/netbird/client/wasm/internal/http" "github.com/netbirdio/netbird/client/wasm/internal/http"
"github.com/netbirdio/netbird/client/wasm/internal/rdp" "github.com/netbirdio/netbird/client/wasm/internal/rdp"
"github.com/netbirdio/netbird/client/wasm/internal/ssh" "github.com/netbirdio/netbird/client/wasm/internal/ssh"
@@ -461,95 +459,6 @@ func createSetLogLevelMethod(client *netbird.Client) js.Func {
}) })
} }
// createStartCaptureMethod creates the programmable packet capture method.
// Returns a JS interface with onpacket callback and stop() method.
//
// Usage from JavaScript:
//
// const cap = await client.startCapture({ filter: "tcp port 443", verbose: true })
// cap.onpacket = (line) => console.log(line)
// const stats = cap.stop()
func createStartCaptureMethod(client *netbird.Client) js.Func {
return js.FuncOf(func(_ js.Value, args []js.Value) any {
var opts js.Value
if len(args) > 0 {
opts = args[0]
}
return createPromise(func(resolve, reject js.Value) {
iface, err := wasmcapture.Start(client, opts)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
return
}
resolve.Invoke(iface)
})
})
}
// captureMethods returns capture() and stopCapture() that share state for
// the console-log shortcut. capture() logs packets to the browser console
// and stopCapture() ends it, like Ctrl+C on the CLI.
//
// Usage from browser devtools console:
//
// await client.capture() // capture all packets
// await client.capture("tcp") // capture with filter
// await client.capture({filter: "host 10.0.0.1", verbose: true})
// client.stopCapture() // stop and print stats
func captureMethods(client *netbird.Client) (startFn, stopFn js.Func) {
var mu sync.Mutex
var active *wasmcapture.Handle
startFn = js.FuncOf(func(_ js.Value, args []js.Value) any {
var opts js.Value
if len(args) > 0 {
opts = args[0]
}
return createPromise(func(resolve, reject js.Value) {
mu.Lock()
defer mu.Unlock()
if active != nil {
active.Stop()
active = nil
}
h, err := wasmcapture.StartConsole(client, opts)
if err != nil {
reject.Invoke(js.ValueOf(fmt.Sprintf("start capture: %v", err)))
return
}
active = h
console := js.Global().Get("console")
console.Call("log", "[capture] started, call client.stopCapture() to stop")
resolve.Invoke(js.Undefined())
})
})
stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
mu.Lock()
defer mu.Unlock()
if active == nil {
js.Global().Get("console").Call("log", "[capture] no active capture")
return js.Undefined()
}
stats := active.Stop()
active = nil
console := js.Global().Get("console")
console.Call("log", fmt.Sprintf("[capture] stopped: %d packets, %d bytes, %d dropped",
stats.Packets, stats.Bytes, stats.Dropped))
return js.Undefined()
})
return startFn, stopFn
}
// createPromise is a helper to create JavaScript promises // createPromise is a helper to create JavaScript promises
func createPromise(handler func(resolve, reject js.Value)) js.Value { func createPromise(handler func(resolve, reject js.Value)) js.Value {
return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any { return js.Global().Get("Promise").New(js.FuncOf(func(_ js.Value, promiseArgs []js.Value) any {
@@ -612,11 +521,6 @@ func createClientObject(client *netbird.Client) js.Value {
obj["statusDetail"] = createStatusDetailMethod(client) obj["statusDetail"] = createStatusDetailMethod(client)
obj["getSyncResponse"] = createGetSyncResponseMethod(client) obj["getSyncResponse"] = createGetSyncResponseMethod(client)
obj["setLogLevel"] = createSetLogLevelMethod(client) obj["setLogLevel"] = createSetLogLevelMethod(client)
obj["startCapture"] = createStartCaptureMethod(client)
capStart, capStop := captureMethods(client)
obj["capture"] = capStart
obj["stopCapture"] = capStop
return js.ValueOf(obj) return js.ValueOf(obj)
} }
View File
@@ -1,176 +0,0 @@
//go:build js
// Package capture bridges the util/capture package to JavaScript via syscall/js.
package capture
import (
"strings"
"sync"
"syscall/js"
netbird "github.com/netbirdio/netbird/client/embed"
)
// Handle holds a running capture session so it can be stopped later.
type Handle struct {
cs *netbird.CaptureSession
stopFn js.Func
stopped bool
}
// Stop ends the capture and returns stats.
func (h *Handle) Stop() netbird.CaptureStats {
if h.stopped {
return h.cs.Stats()
}
h.stopped = true
h.stopFn.Release()
h.cs.Stop()
return h.cs.Stats()
}
func statsToJS(s netbird.CaptureStats) js.Value {
obj := js.Global().Get("Object").Call("create", js.Null())
obj.Set("packets", js.ValueOf(s.Packets))
obj.Set("bytes", js.ValueOf(s.Bytes))
obj.Set("dropped", js.ValueOf(s.Dropped))
return obj
}
// parseOpts extracts filter/verbose/ascii from a JS options value.
func parseOpts(jsOpts js.Value) (filter string, verbose, ascii bool) {
if jsOpts.IsNull() || jsOpts.IsUndefined() {
return
}
if jsOpts.Type() == js.TypeString {
filter = jsOpts.String()
return
}
if jsOpts.Type() != js.TypeObject {
return
}
if f := jsOpts.Get("filter"); !f.IsUndefined() && !f.IsNull() {
filter = f.String()
}
if v := jsOpts.Get("verbose"); !v.IsUndefined() {
verbose = v.Truthy()
}
if a := jsOpts.Get("ascii"); !a.IsUndefined() {
ascii = a.Truthy()
}
return
}
// Start creates a capture session and returns a JS interface for streaming text
// output. The returned object exposes:
//
// onpacket(callback) - set callback(string) for each text line
// stop() - stop capture and return stats { packets, bytes, dropped }
//
// Options: { filter: string, verbose: bool, ascii: bool } or just a filter string.
func Start(client *netbird.Client, jsOpts js.Value) (js.Value, error) {
filter, verbose, ascii := parseOpts(jsOpts)
cb := &jsCallbackWriter{}
cs, err := client.StartCapture(netbird.CaptureOptions{
TextOutput: cb,
Filter: filter,
Verbose: verbose,
ASCII: ascii,
})
if err != nil {
return js.Undefined(), err
}
handle := &Handle{cs: cs}
iface := js.Global().Get("Object").Call("create", js.Null())
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
return statsToJS(handle.Stop())
})
iface.Set("stop", handle.stopFn)
iface.Set("onpacket", js.Undefined())
cb.setInterface(iface)
return iface, nil
}
// StartConsole starts a capture that logs every packet line to console.log.
// Returns a Handle so the caller can stop it later.
func StartConsole(client *netbird.Client, jsOpts js.Value) (*Handle, error) {
filter, verbose, ascii := parseOpts(jsOpts)
cb := &jsCallbackWriter{}
cs, err := client.StartCapture(netbird.CaptureOptions{
TextOutput: cb,
Filter: filter,
Verbose: verbose,
ASCII: ascii,
})
if err != nil {
return nil, err
}
handle := &Handle{cs: cs}
handle.stopFn = js.FuncOf(func(_ js.Value, _ []js.Value) any {
return statsToJS(handle.Stop())
})
iface := js.Global().Get("Object").Call("create", js.Null())
console := js.Global().Get("console")
iface.Set("onpacket", console.Get("log").Call("bind", console, js.ValueOf("[capture]")))
cb.setInterface(iface)
return handle, nil
}
// jsCallbackWriter is an io.Writer that buffers text until a newline, then
// invokes the JS onpacket callback with each complete line.
type jsCallbackWriter struct {
mu sync.Mutex
iface js.Value
buf strings.Builder
}
func (w *jsCallbackWriter) setInterface(iface js.Value) {
w.mu.Lock()
defer w.mu.Unlock()
w.iface = iface
}
func (w *jsCallbackWriter) Write(p []byte) (int, error) {
w.mu.Lock()
w.buf.Write(p)
var lines []string
for {
str := w.buf.String()
idx := strings.IndexByte(str, '\n')
if idx < 0 {
break
}
lines = append(lines, str[:idx])
w.buf.Reset()
if idx+1 < len(str) {
w.buf.WriteString(str[idx+1:])
}
}
iface := w.iface
w.mu.Unlock()
if iface.IsUndefined() {
return len(p), nil
}
cb := iface.Get("onpacket")
if cb.IsUndefined() || cb.IsNull() {
return len(p), nil
}
for _, line := range lines {
cb.Invoke(js.ValueOf(line))
}
return len(p), nil
}
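Write fires the onpacket callback only for complete lines; bytes after the last newline stay buffered for the next call. The same splitting logic, sketched without the syscall/js bridge (assumed imports: fmt, strings):

// Sketch only: the line-buffering behavior of jsCallbackWriter.Write.
var buf strings.Builder
emit := func(line string) { fmt.Println(line) } // stands in for onpacket
write := func(p []byte) {
	buf.Write(p)
	for {
		s := buf.String()
		i := strings.IndexByte(s, '\n')
		if i < 0 {
			return
		}
		emit(s[:i]) // complete line
		buf.Reset()
		buf.WriteString(s[i+1:]) // keep the remainder buffered
	}
}
write([]byte("GET /a\nGET /")) // emits "GET /a", buffers "GET /"
write([]byte("b\n"))           // emits "GET /b"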
View File
@@ -13,9 +13,11 @@ import (
"github.com/cenkalti/backoff/v4" "github.com/cenkalti/backoff/v4"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"google.golang.org/grpc" "google.golang.org/grpc"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/credentials" "google.golang.org/grpc/credentials"
"google.golang.org/grpc/credentials/insecure" "google.golang.org/grpc/credentials/insecure"
"google.golang.org/grpc/keepalive" "google.golang.org/grpc/keepalive"
"google.golang.org/grpc/status"
nbgrpc "github.com/netbirdio/netbird/client/grpc" nbgrpc "github.com/netbirdio/netbird/client/grpc"
"github.com/netbirdio/netbird/flow/proto" "github.com/netbirdio/netbird/flow/proto"
@@ -299,11 +301,12 @@ func defaultBackoff(ctx context.Context, interval time.Duration) backoff.BackOff
}, ctx) }, ctx)
} }
// isContextDone reports whether the local context has been canceled or has
// exceeded its deadline. It deliberately does not inspect gRPC status codes:
// a server- or proxy-sent codes.Canceled / codes.DeadlineExceeded must not
// short-circuit our retry loop, since retrying is the correct response when
// the local context is still alive.
func isContextDone(err error) bool { func isContextDone(err error) bool {
return errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) if errors.Is(err, context.Canceled) || errors.Is(err, context.DeadlineExceeded) {
return true
}
if s, ok := status.FromError(err); ok {
return s.Code() == codes.Canceled || s.Code() == codes.DeadlineExceeded
}
return false
} }
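The two sides differ only in how gRPC status codes are handled: the removed version matched just the local context sentinels, while the restored version also stops retrying when a Canceled or DeadlineExceeded status arrives from the server or a proxy. A minimal sketch of the distinction (assumed imports: context, errors, fmt, google.golang.org/grpc/codes, google.golang.org/grpc/status):

// Sketch only: which check matches which error.
local := context.Canceled                                 // local ctx cancelled
remote := status.Error(codes.Canceled, "closed upstream") // server-sent status

fmt.Println(errors.Is(local, context.Canceled))  // true under both versions
fmt.Println(errors.Is(remote, context.Canceled)) // false: the removed version kept retrying
if s, ok := status.FromError(remote); ok {
	fmt.Println(s.Code() == codes.Canceled) // true: the restored version treats this as done
}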
go.mod
View File
@@ -30,7 +30,6 @@ require (
require ( require (
fyne.io/fyne/v2 v2.7.0 fyne.io/fyne/v2 v2.7.0
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3
github.com/awnumar/memguard v0.23.0 github.com/awnumar/memguard v0.23.0
github.com/aws/aws-sdk-go-v2 v1.38.3 github.com/aws/aws-sdk-go-v2 v1.38.3
github.com/aws/aws-sdk-go-v2/config v1.31.6 github.com/aws/aws-sdk-go-v2/config v1.31.6
@@ -47,7 +46,6 @@ require (
github.com/crowdsecurity/go-cs-bouncer v0.0.21 github.com/crowdsecurity/go-cs-bouncer v0.0.21
github.com/dexidp/dex v0.0.0-00010101000000-000000000000 github.com/dexidp/dex v0.0.0-00010101000000-000000000000
github.com/dexidp/dex/api/v2 v2.4.0 github.com/dexidp/dex/api/v2 v2.4.0
github.com/ebitengine/purego v0.8.4
github.com/eko/gocache/lib/v4 v4.2.0 github.com/eko/gocache/lib/v4 v4.2.0
github.com/eko/gocache/store/go_cache/v4 v4.2.2 github.com/eko/gocache/store/go_cache/v4 v4.2.2
github.com/eko/gocache/store/redis/v4 v4.2.2 github.com/eko/gocache/store/redis/v4 v4.2.2
@@ -71,7 +69,6 @@ require (
github.com/libp2p/go-netroute v0.2.1 github.com/libp2p/go-netroute v0.2.1
github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81 github.com/lrh3321/ipset-go v0.0.0-20250619021614-54a0a98ace81
github.com/mdlayher/socket v0.5.1 github.com/mdlayher/socket v0.5.1
github.com/mdp/qrterminal/v3 v3.2.1
github.com/miekg/dns v1.1.59 github.com/miekg/dns v1.1.59
github.com/mitchellh/hashstructure/v2 v2.0.2 github.com/mitchellh/hashstructure/v2 v2.0.2
github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42 github.com/netbirdio/management-integrations/integrations v0.0.0-20260416123949-2355d972be42
@@ -181,6 +178,7 @@ require (
github.com/docker/docker v28.0.1+incompatible // indirect github.com/docker/docker v28.0.1+incompatible // indirect
github.com/docker/go-connections v0.6.0 // indirect github.com/docker/go-connections v0.6.0 // indirect
github.com/docker/go-units v0.5.0 // indirect github.com/docker/go-units v0.5.0 // indirect
github.com/ebitengine/purego v0.8.4 // indirect
github.com/felixge/httpsnoop v1.0.4 // indirect github.com/felixge/httpsnoop v1.0.4 // indirect
github.com/fredbi/uri v1.1.1 // indirect github.com/fredbi/uri v1.1.1 // indirect
github.com/fyne-io/gl-js v0.2.0 // indirect github.com/fyne-io/gl-js v0.2.0 // indirect
@@ -310,7 +308,6 @@ require (
gopkg.in/square/go-jose.v2 v2.6.0 // indirect gopkg.in/square/go-jose.v2 v2.6.0 // indirect
gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect gopkg.in/tomb.v1 v1.0.0-20141024135613-dd632973f1e7 // indirect
gopkg.in/yaml.v2 v2.4.0 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect
rsc.io/qr v0.2.0 // indirect
) )
replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502 replace github.com/kardianos/service => github.com/netbirdio/service v0.0.0-20240911161631-f62744f42502
go.sum
View File
@@ -15,8 +15,6 @@ fyne.io/fyne/v2 v2.7.0 h1:GvZSpE3X0liU/fqstInVvRsaboIVpIWQ4/sfjDGIGGQ=
fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE= fyne.io/fyne/v2 v2.7.0/go.mod h1:xClVlrhxl7D+LT+BWYmcrW4Nf+dJTvkhnPgji7spAwE=
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4= fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9 h1:829+77I4TaMrcg9B3wf+gHhdSgoCVEgH2czlPXPbfj4=
fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs= fyne.io/systray v1.12.1-0.20260116214250-81f8e1a496f9/go.mod h1:RVwqP9nYMo7h5zViCBHri2FgjXF7H2cub7MAq4NSoLs=
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3 h1:N3IGoHHp9pb6mj1cbXbuaSXV/UMKwmbKLf53nQmtqMA=
git.sr.ht/~jackmordaunt/go-toast/v2 v2.0.3/go.mod h1:QtOLZGz8olr4qH2vWK0QH0w0O4T9fEIjMuWpKUsH7nc=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24 h1:bvDV9vkmnHYOMsOr4WLk+Vo07yKIzd94sVoIqshQ4bU=
github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8= github.com/AdaLogics/go-fuzz-headers v0.0.0-20230811130428-ced1acdcaa24/go.mod h1:8o94RPi1/7XTJvwPpRSzSUedZrtlirdB3r9Z20bi2f8=
github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk= github.com/AppsFlyer/go-sundheit v0.6.0 h1:d2hBvCjBSb2lUsEWGfPigr4MCOt04sxB+Rppl0yUMSk=
@@ -415,8 +413,6 @@ github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42 h1:A1Cq6Ysb0GM0
github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o= github.com/mdlayher/netlink v1.7.3-0.20250113171957-fbb4dce95f42/go.mod h1:BB4YCPDOzfy7FniQ/lxuYQ3dgmM2cZumHbK8RpTjN2o=
github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos= github.com/mdlayher/socket v0.5.1 h1:VZaqt6RkGkt2OE9l3GcC6nZkqD3xKeQLyfleW/uBcos=
github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ= github.com/mdlayher/socket v0.5.1/go.mod h1:TjPLHI1UgwEv5J1B5q0zTZq12A/6H7nKmtTanQE37IQ=
github.com/mdp/qrterminal/v3 v3.2.1 h1:6+yQjiiOsSuXT5n9/m60E54vdgFsw0zhADHhHLrFet4=
github.com/mdp/qrterminal/v3 v3.2.1/go.mod h1:jOTmXvnBsMy5xqLniO0R++Jmjs2sTm9dFSuQ5kpz/SU=
github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k= github.com/mholt/acmez/v2 v2.0.1 h1:3/3N0u1pLjMK4sNEAFSI+bcvzbPhRpY383sy1kLHJ6k=
github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U= github.com/mholt/acmez/v2 v2.0.1/go.mod h1:fX4c9r5jYwMyMsC+7tkYRxHibkOTgta5DIFGoe67e1U=
github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs= github.com/miekg/dns v1.1.59 h1:C9EXc/UToRwKLhK5wKU/I4QVsBUc8kE6MkHBkeypWZs=
@@ -917,5 +913,3 @@ gotest.tools/v3 v3.5.1 h1:EENdUnS3pdur5nybKYIh2Vfgc8IUNBjxDPSjtiJcOzU=
gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU= gotest.tools/v3 v3.5.1/go.mod h1:isy3WKz7GK6uNw/sbHzfKBLvlvXwUyV06n6brMxxopU=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89 h1:mGJaeA61P8dEHTqdvAgc70ZIV3QoUoJcXCRyyjO26OA=
gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q= gvisor.dev/gvisor v0.0.0-20260219192049-0f2374377e89/go.mod h1:QkHjoMIBaYtpVufgwv3keYAbln78mBoCuShZrPrer1Q=
rsc.io/qr v0.2.0 h1:6vBLea5/NRMVTz8V66gipeLycZMl/+UlFmk8DvqQ6WY=
rsc.io/qr v0.2.0/go.mod h1:IF+uZjkb9fqyeF/4tlBoynqmQxUoPfWEKh921coOuXs=
View File
@@ -193,7 +193,7 @@ func (c *Connector) ToStorageConnector() (storage.Connector, error) {
// are stored with types that Dex can open. // are stored with types that Dex can open.
func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) { func mapConnectorToDex(connType string, config map[string]interface{}) (string, map[string]interface{}) {
switch connType { switch connType {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
return "oidc", applyOIDCDefaults(connType, config) return "oidc", applyOIDCDefaults(connType, config)
default: default:
return connType, config return connType, config
@@ -218,8 +218,6 @@ func applyOIDCDefaults(connType string, config map[string]interface{}) map[strin
setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"}) setDefault(augmented, "claimMapping", map[string]string{"email": "preferred_username"})
case "okta", "pocketid": case "okta", "pocketid":
augmented["scopes"] = []string{"openid", "profile", "email", "groups"} augmented["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
augmented["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
} }
return augmented return augmented
View File
@@ -168,7 +168,7 @@ func (p *Provider) buildStorageConnector(cfg *ConnectorConfig) (storage.Connecto
var err error var err error
switch cfg.Type { switch cfg.Type {
case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak", "adfs": case "oidc", "zitadel", "entra", "okta", "pocketid", "authentik", "keycloak":
dexType = "oidc" dexType = "oidc"
configData, err = buildOIDCConnectorConfig(cfg, redirectURI) configData, err = buildOIDCConnectorConfig(cfg, redirectURI)
case "google": case "google":
@@ -220,8 +220,6 @@ func buildOIDCConnectorConfig(cfg *ConnectorConfig, redirectURI string) ([]byte,
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "pocketid": case "pocketid":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"} oidcConfig["scopes"] = []string{"openid", "profile", "email", "groups"}
case "adfs":
oidcConfig["scopes"] = []string{"openid", "profile", "email", "allatclaims"}
} }
return encodeConnectorConfig(oidcConfig) return encodeConnectorConfig(oidcConfig)
} }
@@ -285,7 +283,7 @@ func inferIdentityProviderType(dexType, connectorID string, _ map[string]interfa
// inferOIDCProviderType infers the specific OIDC provider from connector ID // inferOIDCProviderType infers the specific OIDC provider from connector ID
func inferOIDCProviderType(connectorID string) string { func inferOIDCProviderType(connectorID string) string {
connectorIDLower := strings.ToLower(connectorID) connectorIDLower := strings.ToLower(connectorID)
for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak", "adfs"} { for _, provider := range []string{"pocketid", "zitadel", "entra", "okta", "authentik", "keycloak"} {
if strings.Contains(connectorIDLower, provider) { if strings.Contains(connectorIDLower, provider) {
return provider return provider
} }
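The inference above is a case-insensitive substring match over a fixed provider list; illustrative connector IDs (not part of the change):

// Sketch only: illustrative inputs.
// inferOIDCProviderType("corp-Okta-prod") -> "okta" (lowered, substring match)
// inferOIDCProviderType("sso-main")       -> falls through to the function's
//                                            default return (not shown in this hunk)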
View File
@@ -231,20 +231,7 @@ get_upstream_host() {
wait_management_proxy() { wait_management_proxy() {
local proxy_container="${1:-traefik}" local proxy_container="${1:-traefik}"
local use_docker_logs=false
set +e set +e
if [[ "$proxy_container" == "detect-traefik" ]]; then
proxy_container=$(docker ps --format "{{.ID}}\t{{.Image}}\t{{.Ports}}" \
| awk -F'\t' '$2 ~ /traefik/ && $3 ~ /:(80|443)->/ {print $1; exit}')
if [[ -z "$proxy_container" ]]; then
echo "Warning: could not auto-detect Traefik container, log output will be skipped on timeout." > /dev/stderr
else
use_docker_logs=true
fi
fi
echo -n "Waiting for NetBird server to become ready" echo -n "Waiting for NetBird server to become ready"
counter=1 counter=1
while true; do while true; do
@@ -255,13 +242,7 @@ wait_management_proxy() {
if [[ $counter -eq 60 ]]; then if [[ $counter -eq 60 ]]; then
echo "" echo ""
echo "Taking too long. Checking logs..." echo "Taking too long. Checking logs..."
if [[ -n "$proxy_container" ]]; then $DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
if [[ "$use_docker_logs" == "true" ]]; then
docker logs --tail=20 "$proxy_container"
else
$DOCKER_COMPOSE_COMMAND logs --tail=20 "$proxy_container"
fi
fi
$DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server $DOCKER_COMPOSE_COMMAND logs --tail=20 netbird-server
fi fi
echo -n " ." echo -n " ."
@@ -537,7 +518,7 @@ start_services_and_show_instructions() {
$DOCKER_COMPOSE_COMMAND up -d $DOCKER_COMPOSE_COMMAND up -d
sleep 3 sleep 3
wait_management_proxy detect-traefik wait_management_direct
echo -e "$MSG_DONE" echo -e "$MSG_DONE"
print_post_setup_instructions print_post_setup_instructions
View File
@@ -7,6 +7,7 @@ import (
"os" "os"
"slices" "slices"
"strconv" "strconv"
"strings"
"sync" "sync"
"sync/atomic" "sync/atomic"
"time" "time"
@@ -15,9 +16,11 @@ import (
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
"golang.org/x/mod/semver" "golang.org/x/mod/semver"
nbdns "github.com/netbirdio/netbird/dns"
"github.com/netbirdio/netbird/management/internals/controllers/network_map" "github.com/netbirdio/netbird/management/internals/controllers/network_map"
"github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache" "github.com/netbirdio/netbird/management/internals/controllers/network_map/controller/cache"
"github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral" "github.com/netbirdio/netbird/management/internals/modules/peers/ephemeral"
"github.com/netbirdio/netbird/management/internals/modules/zones"
"github.com/netbirdio/netbird/management/internals/server/config" "github.com/netbirdio/netbird/management/internals/server/config"
"github.com/netbirdio/netbird/management/internals/shared/grpc" "github.com/netbirdio/netbird/management/internals/shared/grpc"
"github.com/netbirdio/netbird/management/server/account" "github.com/netbirdio/netbird/management/server/account"
@@ -55,6 +58,13 @@ type Controller struct {
proxyController port_forwarding.Controller proxyController port_forwarding.Controller
integratedPeerValidator integrated_validator.IntegratedValidator integratedPeerValidator integrated_validator.IntegratedValidator
holder *types.Holder
expNewNetworkMap bool
expNewNetworkMapAIDs map[string]struct{}
compactedNetworkMap bool
} }
type bufferUpdate struct { type bufferUpdate struct {
@@ -71,6 +81,29 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
log.Fatal(fmt.Errorf("error creating metrics: %w", err)) log.Fatal(fmt.Errorf("error creating metrics: %w", err))
} }
newNetworkMapBuilder, err := strconv.ParseBool(os.Getenv(network_map.EnvNewNetworkMapBuilder))
if err != nil {
log.WithContext(ctx).Warnf("failed to parse %s, using default value false: %v", network_map.EnvNewNetworkMapBuilder, err)
newNetworkMapBuilder = false
}
compactedNetworkMap := true
compactedEnv := os.Getenv(types.EnvNewNetworkMapCompacted)
parsedCompactedNmap, err := strconv.ParseBool(compactedEnv)
if err != nil && len(compactedEnv) > 0 {
log.WithContext(ctx).Warnf("failed to parse %s, using default value true: %v", types.EnvNewNetworkMapCompacted, err)
}
if err == nil && !parsedCompactedNmap {
log.WithContext(ctx).Info("disabling compacted mode")
compactedNetworkMap = false
}
ids := strings.Split(os.Getenv(network_map.EnvNewNetworkMapAccounts), ",")
expIDs := make(map[string]struct{}, len(ids))
for _, id := range ids {
expIDs[id] = struct{}{}
}
return &Controller{ return &Controller{
repo: newRepository(store), repo: newRepository(store),
metrics: nMetrics, metrics: nMetrics,
@@ -84,6 +117,12 @@ func NewController(ctx context.Context, store store.Store, metrics telemetry.App
proxyController: proxyController, proxyController: proxyController,
EphemeralPeersManager: ephemeralPeersManager, EphemeralPeersManager: ephemeralPeersManager,
holder: types.NewHolder(),
expNewNetworkMap: newNetworkMapBuilder,
expNewNetworkMapAIDs: expIDs,
compactedNetworkMap: compactedNetworkMap,
} }
} }
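The constructor derives three settings from the environment: NB_EXPERIMENT_NETWORK_MAP toggles the experimental builder globally, NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS opts in individual accounts, and the variable named by types.EnvNewNetworkMapCompacted (its literal name is not shown in this diff) can switch compacted mode off. A minimal sketch of per-account opt-in, with illustrative account IDs:

// Sketch only: gating the experimental builder per account.
os.Setenv(network_map.EnvNewNetworkMapBuilder, "false")
os.Setenv(network_map.EnvNewNetworkMapAccounts, "accountA,accountB")
// experimentalNetworkMap("accountA") -> true  (listed)
// experimentalNetworkMap("other")    -> false (global flag off)
// Note: with the accounts variable unset, strings.Split("", ",") still yields
// a single empty entry, so the opt-in set contains "" but no real account ID.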
@@ -114,9 +153,17 @@ func (c *Controller) CountStreams() int {
func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error { func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName()) log.WithContext(ctx).Tracef("updating peers for account %s from %s", accountID, util.GetCallerName())
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) var (
if err != nil { account *types.Account
return fmt.Errorf("failed to get account: %v", err) err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return fmt.Errorf("failed to get account: %v", err)
}
} }
globalStart := time.Now() globalStart := time.Now()
@@ -150,6 +197,10 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
routers := account.GetResourceRoutersMap() routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers() groupIDToUserIDs := account.GetActiveGroupUsers()
if c.experimentalNetworkMap(accountID) {
c.initNetworkMapBuilderIfNeeded(account, approvedPeersMap)
}
proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers) proxyNetworkMaps, err := c.proxyController.GetProxyNetworkMapsAll(ctx, accountID, account.Peers)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err) log.WithContext(ctx).Errorf("failed to get proxy network maps: %v", err)
@@ -192,7 +243,16 @@ func (c *Controller) sendUpdateAccountPeers(ctx context.Context, accountID strin
c.metrics.CountCalcPostureChecksDuration(time.Since(start)) c.metrics.CountCalcPostureChecksDuration(time.Since(start))
start = time.Now() start = time.Now()
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountID):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, p.AccountID, p.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, p.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start)) c.metrics.CountCalcPeerNetworkMapDuration(time.Since(start))
@@ -257,10 +317,11 @@ func (c *Controller) bufferSendUpdateAccountPeers(ctx context.Context, accountID
// UpdateAccountPeers updates all peers that belong to an account. // UpdateAccountPeers updates all peers that belong to an account.
// Should be called when changes have to be synced to peers. // Should be called when changes have to be synced to peers.
func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { func (c *Controller) UpdateAccountPeers(ctx context.Context, accountID string) error {
if c.accountManagerMetrics != nil { if err := c.RecalculateNetworkMapCache(ctx, accountID); err != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation)) return fmt.Errorf("recalculate network map cache: %v", err)
} }
return c.sendUpdateAccountPeers(ctx, accountID) return c.sendUpdateAccountPeers(ctx, accountID)
} }
@@ -310,7 +371,16 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return err return err
} }
remotePeerNetworkMap := account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) var remotePeerNetworkMap *types.NetworkMap
switch {
case c.experimentalNetworkMap(accountId):
remotePeerNetworkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
case c.compactedNetworkMap:
remotePeerNetworkMap = account.GetPeerNetworkMapFromComponents(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
default:
remotePeerNetworkMap = account.GetPeerNetworkMap(ctx, peerId, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -334,13 +404,9 @@ func (c *Controller) UpdateAccountPeer(ctx context.Context, accountId string, pe
return nil return nil
} }
func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName()) log.WithContext(ctx).Tracef("buffer updating peers for account %s from %s", accountID, util.GetCallerName())
if c.accountManagerMetrics != nil {
c.accountManagerMetrics.CountUpdateAccountPeersTriggered(string(reason.Resource), string(reason.Operation))
}
bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{}) bufUpd, _ := c.accountUpdateLocks.LoadOrStore(accountID, &bufferUpdate{})
b := bufUpd.(*bufferUpdate) b := bufUpd.(*bufferUpdate)
@@ -355,14 +421,14 @@ func (c *Controller) BufferUpdateAccountPeers(ctx context.Context, accountID str
go func() { go func() {
defer b.mu.Unlock() defer b.mu.Unlock()
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.UpdateAccountPeers(ctx, accountID)
if !b.update.Load() { if !b.update.Load() {
return return
} }
b.update.Store(false) b.update.Store(false)
if b.next == nil { if b.next == nil {
b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() { b.next = time.AfterFunc(time.Duration(c.updateAccountPeersBufferInterval.Load()), func() {
_ = c.sendUpdateAccountPeers(ctx, accountID) _ = c.UpdateAccountPeers(ctx, accountID)
}) })
return return
} }
@@ -385,9 +451,17 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, emptyMap, nil, 0, nil return peer, emptyMap, nil, 0, nil
} }
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID) var (
if err != nil { account *types.Account
return nil, nil, nil, 0, err err error
)
if c.experimentalNetworkMap(accountID) {
account = c.getAccountFromHolderOrInit(ctx, accountID)
} else {
account, err = c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return nil, nil, nil, 0, err
}
} }
account.InjectProxyPolicies(ctx) account.InjectProxyPolicies(ctx)
@@ -419,10 +493,20 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return nil, nil, nil, 0, err return nil, nil, nil, 0, err
} }
resourcePolicies := account.GetResourcePoliciesMap() var networkMap *types.NetworkMap
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers() if c.experimentalNetworkMap(accountID) {
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs) networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peer.ID, approvedPeersMap, peersCustomZone, accountZones, c.accountManagerMetrics)
} else {
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, approvedPeersMap, resourcePolicies, routers, c.accountManagerMetrics, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
@@ -434,6 +518,108 @@ func (c *Controller) GetValidatedPeerWithMap(ctx context.Context, isRequiresAppr
return peer, networkMap, postureChecks, dnsFwdPort, nil return peer, networkMap, postureChecks, dnsFwdPort, nil
} }
func (c *Controller) initNetworkMapBuilderIfNeeded(account *types.Account, validatedPeers map[string]struct{}) {
c.enrichAccountFromHolder(account)
account.InitNetworkMapBuilderIfNeeded(validatedPeers)
}
func (c *Controller) getPeerNetworkMapExp(
ctx context.Context,
accountId string,
peerId string,
validatedPeers map[string]struct{},
peersCustomZone nbdns.CustomZone,
accountZones []*zones.Zone,
metrics *telemetry.AccountManagerMetrics,
) *types.NetworkMap {
account := c.getAccountFromHolderOrInit(ctx, accountId)
if account == nil {
log.WithContext(ctx).Warnf("account %s not found in holder when getting peer network map", accountId)
return &types.NetworkMap{
Network: &types.Network{},
}
}
return account.GetPeerNetworkMapExp(ctx, peerId, peersCustomZone, accountZones, validatedPeers, metrics)
}
func (c *Controller) onPeersAddedUpdNetworkMapCache(account *types.Account, peerIds ...string) {
c.enrichAccountFromHolder(account)
account.OnPeersAddedUpdNetworkMapCache(peerIds...)
}
func (c *Controller) onPeerDeletedUpdNetworkMapCache(account *types.Account, peerId string) error {
c.enrichAccountFromHolder(account)
return account.OnPeerDeletedUpdNetworkMapCache(peerId)
}
func (c *Controller) UpdatePeerInNetworkMapCache(accountId string, peer *nbpeer.Peer) {
account := c.getAccountFromHolder(accountId)
if account == nil {
return
}
account.UpdatePeerInNetworkMapCache(peer)
}
func (c *Controller) recalculateNetworkMapCache(account *types.Account, validatedPeers map[string]struct{}) {
account.RecalculateNetworkMapCache(validatedPeers)
c.updateAccountInHolder(account)
}
func (c *Controller) RecalculateNetworkMapCache(ctx context.Context, accountId string) error {
if c.experimentalNetworkMap(accountId) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountId)
if err != nil {
return err
}
validatedPeers, err := c.integratedPeerValidator.GetValidatedPeers(ctx, account.Id, maps.Values(account.Groups), maps.Values(account.Peers), account.Settings.Extra)
if err != nil {
log.WithContext(ctx).Errorf("failed to get validate peers: %v", err)
return err
}
c.recalculateNetworkMapCache(account, validatedPeers)
}
return nil
}
func (c *Controller) experimentalNetworkMap(accountId string) bool {
_, ok := c.expNewNetworkMapAIDs[accountId]
return c.expNewNetworkMap || ok
}
func (c *Controller) enrichAccountFromHolder(account *types.Account) {
a := c.holder.GetAccount(account.Id)
if a == nil {
c.holder.AddAccount(account)
return
}
account.NetworkMapCache = a.NetworkMapCache
if account.NetworkMapCache == nil {
return
}
c.holder.AddAccount(account)
}
func (c *Controller) getAccountFromHolder(accountID string) *types.Account {
return c.holder.GetAccount(accountID)
}
func (c *Controller) getAccountFromHolderOrInit(ctx context.Context, accountID string) *types.Account {
a := c.holder.GetAccount(accountID)
if a != nil {
return a
}
account, err := c.holder.LoadOrStoreFunc(ctx, accountID, c.requestBuffer.GetAccountWithBackpressure)
if err != nil {
return nil
}
return account
}
func (c *Controller) updateAccountInHolder(account *types.Account) {
c.holder.AddAccount(account)
}
// GetDNSDomain returns the configured dnsDomain // GetDNSDomain returns the configured dnsDomain
func (c *Controller) GetDNSDomain(settings *types.Settings) string { func (c *Controller) GetDNSDomain(settings *types.Settings) string {
if settings == nil { if settings == nil {
@@ -570,7 +756,16 @@ func isPeerInPolicySourceGroups(account *types.Account, peerID string, policy *t
} }
func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error { func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerIDs []string) error {
err := c.bufferSendUpdateAccountPeers(ctx, accountID) peers, err := c.repo.GetPeersByIDs(ctx, accountID, peerIDs)
if err != nil {
return fmt.Errorf("failed to get peers by ids: %w", err)
}
for _, peer := range peers {
c.UpdatePeerInNetworkMapCache(accountID, peer)
}
err = c.bufferSendUpdateAccountPeers(ctx, accountID)
if err != nil { if err != nil {
log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err) log.WithContext(ctx).Errorf("failed to buffer update account peers for peer update in account %s: %v", accountID, err)
} }
@@ -580,6 +775,14 @@ func (c *Controller) OnPeersUpdated(ctx context.Context, accountID string, peerI
func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error { func (c *Controller) OnPeersAdded(ctx context.Context, accountID string, peerIDs []string) error {
log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs) log.WithContext(ctx).Debugf("OnPeersAdded call to add peers: %v", peerIDs)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
return err
}
log.WithContext(ctx).Debugf("peers are ready to be added to networkmap cache: %v", peerIDs)
c.onPeersAddedUpdNetworkMapCache(account, peerIDs...)
}
return c.bufferSendUpdateAccountPeers(ctx, accountID) return c.bufferSendUpdateAccountPeers(ctx, accountID)
} }
@@ -614,6 +817,19 @@ func (c *Controller) OnPeersDeleted(ctx context.Context, accountID string, peerI
MessageType: network_map.MessageTypeNetworkMap, MessageType: network_map.MessageTypeNetworkMap,
}) })
c.peersUpdateManager.CloseChannel(ctx, peerID) c.peersUpdateManager.CloseChannel(ctx, peerID)
if c.experimentalNetworkMap(accountID) {
account, err := c.requestBuffer.GetAccountWithBackpressure(ctx, accountID)
if err != nil {
log.WithContext(ctx).Errorf("failed to get account %s: %v", accountID, err)
continue
}
err = c.onPeerDeletedUpdNetworkMapCache(account, peerID)
if err != nil {
log.WithContext(ctx).Errorf("failed to update network map cache for deleted peer %s in account %s: %v", peerID, accountID, err)
continue
}
}
} }
return c.bufferSendUpdateAccountPeers(ctx, accountID) return c.bufferSendUpdateAccountPeers(ctx, accountID)
@@ -656,11 +872,21 @@ func (c *Controller) GetNetworkMap(ctx context.Context, peerID string) (*types.N
return nil, err return nil, err
} }
account.InjectProxyPolicies(ctx) var networkMap *types.NetworkMap
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap() if c.experimentalNetworkMap(peer.AccountID) {
groupIDToUserIDs := account.GetActiveGroupUsers() networkMap = c.getPeerNetworkMapExp(ctx, peer.AccountID, peerID, validatedPeers, peersCustomZone, accountZones, nil)
networkMap := account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs) } else {
account.InjectProxyPolicies(ctx)
resourcePolicies := account.GetResourcePoliciesMap()
routers := account.GetResourceRoutersMap()
groupIDToUserIDs := account.GetActiveGroupUsers()
if c.compactedNetworkMap {
networkMap = account.GetPeerNetworkMapFromComponents(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
} else {
networkMap = account.GetPeerNetworkMap(ctx, peer.ID, peersCustomZone, accountZones, validatedPeers, resourcePolicies, routers, nil, groupIDToUserIDs)
}
}
proxyNetworkMap, ok := proxyNetworkMaps[peer.ID] proxyNetworkMap, ok := proxyNetworkMaps[peer.ID]
if ok { if ok {
View File
@@ -12,15 +12,18 @@ import (
) )
const ( const (
EnvNewNetworkMapBuilder = "NB_EXPERIMENT_NETWORK_MAP"
EnvNewNetworkMapAccounts = "NB_EXPERIMENT_NETWORK_MAP_ACCOUNTS"
DnsForwarderPort = nbdns.ForwarderServerPort DnsForwarderPort = nbdns.ForwarderServerPort
OldForwarderPort = nbdns.ForwarderClientPort OldForwarderPort = nbdns.ForwarderClientPort
DnsForwarderPortMinVersion = "v0.59.0" DnsForwarderPortMinVersion = "v0.59.0"
) )
type Controller interface { type Controller interface {
UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error UpdateAccountPeers(ctx context.Context, accountID string) error
UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error UpdateAccountPeer(ctx context.Context, accountId string, peerId string) error
BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error BufferUpdateAccountPeers(ctx context.Context, accountID string) error
GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error) GetValidatedPeerWithMap(ctx context.Context, isRequiresApproval bool, accountID string, p *nbpeer.Peer) (*nbpeer.Peer, *types.NetworkMap, []*posture.Checks, int64, error)
GetDNSDomain(settings *types.Settings) string GetDNSDomain(settings *types.Settings) string
StartWarmup(context.Context) StartWarmup(context.Context)
View File
@@ -44,17 +44,17 @@ func (m *MockController) EXPECT() *MockControllerMockRecorder {
} }
// BufferUpdateAccountPeers mocks base method. // BufferUpdateAccountPeers mocks base method.
func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { func (m *MockController) BufferUpdateAccountPeers(ctx context.Context, accountID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID, reason) ret := m.ctrl.Call(m, "BufferUpdateAccountPeers", ctx, accountID)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers. // BufferUpdateAccountPeers indicates an expected call of BufferUpdateAccountPeers.
func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { func (mr *MockControllerMockRecorder) BufferUpdateAccountPeers(ctx, accountID any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID, reason) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "BufferUpdateAccountPeers", reflect.TypeOf((*MockController)(nil).BufferUpdateAccountPeers), ctx, accountID)
} }
// CountStreams mocks base method. // CountStreams mocks base method.
@@ -238,15 +238,15 @@ func (mr *MockControllerMockRecorder) UpdateAccountPeer(ctx, accountId, peerId a
} }
// UpdateAccountPeers mocks base method. // UpdateAccountPeers mocks base method.
func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) error { func (m *MockController) UpdateAccountPeers(ctx context.Context, accountID string) error {
m.ctrl.T.Helper() m.ctrl.T.Helper()
ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID, reason) ret := m.ctrl.Call(m, "UpdateAccountPeers", ctx, accountID)
ret0, _ := ret[0].(error) ret0, _ := ret[0].(error)
return ret0 return ret0
} }
// UpdateAccountPeers indicates an expected call of UpdateAccountPeers. // UpdateAccountPeers indicates an expected call of UpdateAccountPeers.
func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID, reason any) *gomock.Call { func (mr *MockControllerMockRecorder) UpdateAccountPeers(ctx, accountID any) *gomock.Call {
mr.mock.ctrl.T.Helper() mr.mock.ctrl.T.Helper()
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID, reason) return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "UpdateAccountPeers", reflect.TypeOf((*MockController)(nil).UpdateAccountPeers), ctx, accountID)
} }
View File
@@ -62,7 +62,7 @@ func (a *MockAccountManager) GetDeletePeerCalls() int {
 	return a.deletePeerCalls
 }
 
-func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string, reason types.UpdateReason) {
+func (a *MockAccountManager) BufferUpdateAccountPeers(ctx context.Context, accountID string) {
 	a.mu.Lock()
 	defer a.mu.Unlock()
 	if a.bufferUpdateCalls == nil {
@@ -248,7 +248,7 @@ func TestCleanupSchedulingBehaviorIsBatched(t *testing.T) {
 				return err
 			}
 		}
-		mockAM.BufferUpdateAccountPeers(ctx, accountID, types.UpdateReason{})
+		mockAM.BufferUpdateAccountPeers(ctx, accountID)
 		return nil
 	}).
 		Times(1)

View File

@@ -20,7 +20,6 @@ import (
 	"github.com/netbirdio/netbird/management/server/permissions/modules"
 	"github.com/netbirdio/netbird/management/server/permissions/operations"
 	"github.com/netbirdio/netbird/management/server/store"
-	"github.com/netbirdio/netbird/management/server/types"
 	"github.com/netbirdio/netbird/shared/management/status"
 )
 
@@ -179,7 +178,7 @@ func (m *managerImpl) DeletePeers(ctx context.Context, accountID string, peerIDs
 		}
 	}
 
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourcePeer, Operation: types.UpdateOperationDelete})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }

View File

@@ -11,9 +11,9 @@ import (
 
 // Manager defines the interface for proxy operations
 type Manager interface {
-	Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error)
-	Disconnect(ctx context.Context, proxyID, sessionID string) error
-	Heartbeat(ctx context.Context, p *Proxy) error
+	Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error
+	Disconnect(ctx context.Context, proxyID string) error
+	Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
 	GetActiveClusterAddresses(ctx context.Context) ([]string, error)
 	GetActiveClusters(ctx context.Context) ([]Cluster, error)
 	ClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
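
Seen from a caller, the head side drops the per-connection session handle entirely: Connect no longer returns the stored *Proxy, and Heartbeat takes raw identifiers instead of that handle. A hypothetical caller fragment sketched against the head-side signatures, in the same package as Manager (the proxy ID and addresses are illustrative):

func registerAndBeat(ctx context.Context, m Manager) error {
	// Head side: Connect returns nothing to hold on to between calls.
	if err := m.Connect(ctx, "proxy-1", "cluster.example.com:443", "203.0.113.10", nil); err != nil {
		return err
	}
	// Heartbeat re-sends the identifiers rather than a *Proxy handle.
	if err := m.Heartbeat(ctx, "proxy-1", "cluster.example.com:443", "203.0.113.10"); err != nil {
		return err
	}
	return m.Disconnect(ctx, "proxy-1")
}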

View File

@@ -13,8 +13,7 @@ import (
 // store defines the interface for proxy persistence operations
 type store interface {
 	SaveProxy(ctx context.Context, p *proxy.Proxy) error
-	DisconnectProxy(ctx context.Context, proxyID, sessionID string) error
-	UpdateProxyHeartbeat(ctx context.Context, p *proxy.Proxy) error
+	UpdateProxyHeartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error
 	GetActiveProxyClusterAddresses(ctx context.Context) ([]string, error)
 	GetActiveProxyClusters(ctx context.Context) ([]proxy.Cluster, error)
 	GetClusterSupportsCustomPorts(ctx context.Context, clusterAddr string) *bool
@@ -44,7 +43,7 @@ func NewManager(store store, meter metric.Meter) (*Manager, error) {
 
 // Connect registers a new proxy connection in the database.
 // capabilities may be nil for old proxies that do not report them.
-func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) (*proxy.Proxy, error) {
+func (m Manager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *proxy.Capabilities) error {
 	now := time.Now()
 	var caps proxy.Capabilities
 	if capabilities != nil {
@@ -52,7 +51,6 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
 	}
 	p := &proxy.Proxy{
 		ID: proxyID,
-		SessionID: sessionID,
 		ClusterAddress: clusterAddress,
 		IPAddress: ipAddress,
 		LastSeen: now,
@@ -63,42 +61,48 @@ func (m Manager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress
 	if err := m.store.SaveProxy(ctx, p); err != nil {
 		log.WithContext(ctx).Errorf("failed to register proxy %s: %v", proxyID, err)
-		return nil, err
-	}
-
-	log.WithContext(ctx).WithFields(log.Fields{
-		"proxyID": proxyID,
-		"sessionID": sessionID,
-		"clusterAddress": clusterAddress,
-		"ipAddress": ipAddress,
-	}).Info("proxy connected")
-
-	return p, nil
-}
-
-// Disconnect marks a proxy as disconnected in the database.
-func (m Manager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
-	if err := m.store.DisconnectProxy(ctx, proxyID, sessionID); err != nil {
-		log.WithContext(ctx).Errorf("failed to disconnect proxy %s session %s: %v", proxyID, sessionID, err)
 		return err
 	}
 
 	log.WithContext(ctx).WithFields(log.Fields{
 		"proxyID": proxyID,
-		"sessionID": sessionID,
+		"clusterAddress": clusterAddress,
+		"ipAddress": ipAddress,
+	}).Info("proxy connected")
+
+	return nil
+}
+
+// Disconnect marks a proxy as disconnected in the database
+func (m Manager) Disconnect(ctx context.Context, proxyID string) error {
+	now := time.Now()
+	p := &proxy.Proxy{
+		ID: proxyID,
+		Status: "disconnected",
+		DisconnectedAt: &now,
+		LastSeen: now,
+	}
+
+	if err := m.store.SaveProxy(ctx, p); err != nil {
+		log.WithContext(ctx).Errorf("failed to disconnect proxy %s: %v", proxyID, err)
+		return err
+	}
+
+	log.WithContext(ctx).WithFields(log.Fields{
+		"proxyID": proxyID,
 	}).Info("proxy disconnected")
 	return nil
 }
 
-// Heartbeat updates the proxy's last seen timestamp.
-func (m Manager) Heartbeat(ctx context.Context, p *proxy.Proxy) error {
-	if err := m.store.UpdateProxyHeartbeat(ctx, p); err != nil {
-		log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", p.ID, err)
+// Heartbeat updates the proxy's last seen timestamp
+func (m Manager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
+	if err := m.store.UpdateProxyHeartbeat(ctx, proxyID, clusterAddress, ipAddress); err != nil {
+		log.WithContext(ctx).Debugf("failed to update proxy %s heartbeat: %v", proxyID, err)
 		return err
 	}
-	log.WithContext(ctx).Tracef("updated heartbeat for proxy %s session %s", p.ID, p.SessionID)
+	log.WithContext(ctx).Tracef("updated heartbeat for proxy %s", proxyID)
 	m.metrics.IncrementProxyHeartbeatCount()
 	return nil
 }
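
The two Disconnect bodies differ in more than arity. The base side fences the write with (proxyID, sessionID) through a dedicated store operation, so only the connection that registered the session can close it; the head side rebuilds a row keyed on ID alone and saves it, which assumes SaveProxy upserts by primary key. One consequence, sketched with illustrative IDs and ordering, is that a stale disconnect can no longer be told apart from a current one:

// Proxy "proxy-1" drops and promptly reconnects:
_ = m.Connect(ctx, "proxy-1", "cluster.example.com:443", "203.0.113.10", nil)

// A late Disconnect from the old, already-dead connection arrives afterwards.
// With no sessionID to compare, the fresh row is marked disconnected too:
_ = m.Disconnect(ctx, "proxy-1")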

View File

@@ -93,32 +93,31 @@ func (mr *MockManagerMockRecorder) ClusterSupportsCrowdSec(ctx, clusterAddr inte
 }
 
 // Connect mocks base method.
-func (m *MockManager) Connect(ctx context.Context, proxyID, sessionID, clusterAddress, ipAddress string, capabilities *Capabilities) (*Proxy, error) {
+func (m *MockManager) Connect(ctx context.Context, proxyID, clusterAddress, ipAddress string, capabilities *Capabilities) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "Connect", ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
-	ret0, _ := ret[0].(*Proxy)
-	ret1, _ := ret[1].(error)
-	return ret0, ret1
+	ret := m.ctrl.Call(m, "Connect", ctx, proxyID, clusterAddress, ipAddress, capabilities)
+	ret0, _ := ret[0].(error)
+	return ret0
 }
 
 // Connect indicates an expected call of Connect.
-func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
+func (mr *MockManagerMockRecorder) Connect(ctx, proxyID, clusterAddress, ipAddress, capabilities interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, sessionID, clusterAddress, ipAddress, capabilities)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Connect", reflect.TypeOf((*MockManager)(nil).Connect), ctx, proxyID, clusterAddress, ipAddress, capabilities)
 }
 
 // Disconnect mocks base method.
-func (m *MockManager) Disconnect(ctx context.Context, proxyID, sessionID string) error {
+func (m *MockManager) Disconnect(ctx context.Context, proxyID string) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID, sessionID)
+	ret := m.ctrl.Call(m, "Disconnect", ctx, proxyID)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
 // Disconnect indicates an expected call of Disconnect.
-func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID, sessionID interface{}) *gomock.Call {
+func (mr *MockManagerMockRecorder) Disconnect(ctx, proxyID interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID, sessionID)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Disconnect", reflect.TypeOf((*MockManager)(nil).Disconnect), ctx, proxyID)
 }
 
 // GetActiveClusterAddresses mocks base method.
@@ -152,17 +151,17 @@ func (mr *MockManagerMockRecorder) GetActiveClusters(ctx interface{}) *gomock.Ca
 }
 
 // Heartbeat mocks base method.
-func (m *MockManager) Heartbeat(ctx context.Context, p *Proxy) error {
+func (m *MockManager) Heartbeat(ctx context.Context, proxyID, clusterAddress, ipAddress string) error {
 	m.ctrl.T.Helper()
-	ret := m.ctrl.Call(m, "Heartbeat", ctx, p)
+	ret := m.ctrl.Call(m, "Heartbeat", ctx, proxyID, clusterAddress, ipAddress)
 	ret0, _ := ret[0].(error)
 	return ret0
 }
 
 // Heartbeat indicates an expected call of Heartbeat.
-func (mr *MockManagerMockRecorder) Heartbeat(ctx, p interface{}) *gomock.Call {
+func (mr *MockManagerMockRecorder) Heartbeat(ctx, proxyID, clusterAddress, ipAddress interface{}) *gomock.Call {
 	mr.mock.ctrl.T.Helper()
-	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, p)
+	return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Heartbeat", reflect.TypeOf((*MockManager)(nil).Heartbeat), ctx, proxyID, clusterAddress, ipAddress)
 }
 
 // MockController is a mock of Controller interface.

View File

@@ -18,13 +18,12 @@ type Capabilities struct {
 // Proxy represents a reverse proxy instance
 type Proxy struct {
 	ID             string    `gorm:"primaryKey;type:varchar(255)"`
-	SessionID      string    `gorm:"type:varchar(36)"`
 	ClusterAddress string    `gorm:"type:varchar(255);not null;index:idx_proxy_cluster_status"`
 	IPAddress      string    `gorm:"type:varchar(45)"`
 	LastSeen       time.Time `gorm:"not null;index:idx_proxy_last_seen"`
 	ConnectedAt    *time.Time
 	DisconnectedAt *time.Time
 	Status         string       `gorm:"type:varchar(20);not null;index:idx_proxy_cluster_status"`
 	Capabilities   Capabilities `gorm:"embedded"`
 	CreatedAt      time.Time
 	UpdatedAt      time.Time
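
One migration caveat worth flagging: if this table's schema is managed with gorm's AutoMigrate, re-running it against a database created by the base side will not remove the orphaned session_id column, since AutoMigrate only adds missing columns and indexes. A minimal sketch under that assumption (the SQLite driver and file name are illustrative; the model's import path is taken from the imports shown further below):

package main

import (
	"log"

	"gorm.io/driver/sqlite"
	"gorm.io/gorm"

	"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
)

func main() {
	db, err := gorm.Open(sqlite.Open("mgmt.db"), &gorm.Config{})
	if err != nil {
		log.Fatalf("open store: %v", err)
	}
	// Adds or updates columns and indexes for the head-side model;
	// a stale session_id column, if present, is left in place.
	if err := db.AutoMigrate(&proxy.Proxy{}); err != nil {
		log.Fatalf("migrate proxy model: %v", err)
	}
}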

View File

@@ -85,7 +85,7 @@ func setupL4Test(t *testing.T, customPortsSupported *bool) (*Manager, store.Stor
 	accountMgr := &mock_server.MockAccountManager{
 		StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
-		UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
+		UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
 		GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
 			return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
 		},

View File

@@ -25,7 +25,6 @@ import (
 	"github.com/netbirdio/netbird/management/server/permissions/modules"
 	"github.com/netbirdio/netbird/management/server/permissions/operations"
 	"github.com/netbirdio/netbird/management/server/store"
-	"github.com/netbirdio/netbird/management/server/types"
 	"github.com/netbirdio/netbird/shared/management/status"
 )
 
@@ -232,7 +231,7 @@ func (m *Manager) CreateService(ctx context.Context, accountID, userID string, s
 	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
 
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return s, nil
 }
@@ -516,7 +515,7 @@ func (m *Manager) UpdateService(ctx context.Context, accountID, userID string, s
 	}
 
 	m.sendServiceUpdateNotifications(ctx, accountID, service, updateInfo)
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return service, nil
 }
@@ -820,7 +819,7 @@ func (m *Manager) DeleteService(ctx context.Context, accountID, userID, serviceI
 	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
 
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }
@@ -861,7 +860,7 @@ func (m *Manager) DeleteAllServices(ctx context.Context, accountID, userID strin
 		m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", oidcCfg), svc.ProxyCluster)
 	}
 
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }
@@ -917,7 +916,7 @@ func (m *Manager) ReloadService(ctx context.Context, accountID, serviceID string
 	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, s.ToProtoMapping(service.Update, "", m.proxyController.GetOIDCValidationConfig()), s.ProxyCluster)
 
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationUpdate})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }
@@ -1099,7 +1098,7 @@ func (m *Manager) CreateServiceFromPeer(ctx context.Context, accountID, peerID s
 	}
 
 	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Create, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationCreate})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	serviceURL := "https://" + svc.Domain
 	if service.IsL4Protocol(svc.Mode) {
@@ -1211,7 +1210,7 @@ func (m *Manager) deletePeerService(ctx context.Context, accountID, peerID, serv
 	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
 
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }
@@ -1262,7 +1261,7 @@ func (m *Manager) deleteExpiredPeerService(ctx context.Context, accountID, peerI
 	meta := addPeerInfoToEventMeta(svc.EventMeta(), peer)
 	m.accountManager.StoreEvent(ctx, peerID, serviceID, accountID, activity.PeerServiceExposeExpired, meta)
 	m.proxyController.SendServiceUpdateToCluster(ctx, accountID, svc.ToProtoMapping(service.Delete, "", m.proxyController.GetOIDCValidationConfig()), svc.ProxyCluster)
-	m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceService, Operation: types.UpdateOperationDelete})
+	m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }
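
All eight call sites in this file follow the same mechanical rewrite; pulled out of the hunks above, the pattern is:

// Base side: every peer fan-out carried a typed reason for the update.
m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{
	Resource:  types.UpdateResourceService,
	Operation: types.UpdateOperationCreate, // or Update/Delete per call site
})

// Head side: the bare call; no metadata travels with the fan-out.
m.accountManager.UpdateAccountPeers(ctx, accountID)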

View File

@@ -447,7 +447,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
 			StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
 				storedActivity = activityID.(activity.Activity)
 			},
-			UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
+			UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
 		}
 
 		mockStore.EXPECT().
@@ -549,7 +549,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
 			StoreEventFunc: func(_ context.Context, _, _, _ string, activityID activity.ActivityDescriber, _ map[string]any) {
 				storedActivity = activityID.(activity.Activity)
 			},
-			UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
+			UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
 		}
 
 		mockStore.EXPECT().
@@ -593,7 +593,7 @@ func TestDeletePeerService_SourcePeerValidation(t *testing.T) {
 			StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, meta map[string]any) {
 				storedMeta = meta
 			},
-			UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
+			UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
 		}
 
 		mockStore.EXPECT().
@@ -704,7 +704,7 @@ func setupIntegrationTest(t *testing.T) (*Manager, store.Store) {
 	accountMgr := &mock_server.MockAccountManager{
 		StoreEventFunc: func(_ context.Context, _, _, _ string, _ activity.ActivityDescriber, _ map[string]any) {},
-		UpdateAccountPeersFunc: func(_ context.Context, _ string, _ types.UpdateReason) {},
+		UpdateAccountPeersFunc: func(_ context.Context, _ string) {},
 		GetGroupByNameFunc: func(ctx context.Context, groupName, accountID, userID string) (*types.Group, error) {
 			return testStore.GetGroupByName(ctx, store.LockingStrengthNone, accountID, groupName)
 		},
@@ -1173,7 +1173,7 @@ func TestDeleteService_DeletesTargets(t *testing.T) {
 	mockAcct.EXPECT().
 		StoreEvent(ctx, userID, service.ID, accountID, activity.ServiceDeleted, gomock.Any())
 	mockAcct.EXPECT().
-		UpdateAccountPeers(ctx, accountID, gomock.Any())
+		UpdateAccountPeers(ctx, accountID)
 
 	err = mgr.DeleteService(ctx, accountID, userID, service.ID)
 	require.NoError(t, err)

View File

@@ -11,7 +11,6 @@ import (
 	"github.com/netbirdio/netbird/management/server/permissions/modules"
 	"github.com/netbirdio/netbird/management/server/permissions/operations"
 	"github.com/netbirdio/netbird/management/server/store"
-	"github.com/netbirdio/netbird/management/server/types"
 	"github.com/netbirdio/netbird/shared/management/status"
 )
 
@@ -145,7 +144,7 @@ func (m *managerImpl) UpdateZone(ctx context.Context, accountID, userID string,
 	m.accountManager.StoreEvent(ctx, userID, zone.ID, accountID, activity.DNSZoneUpdated, zone.EventMeta())
 
-	go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationUpdate})
+	go m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return zone, nil
 }
@@ -207,7 +206,7 @@ func (m *managerImpl) DeleteZone(ctx context.Context, accountID, userID, zoneID
 		event()
 	}
 
-	go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZone, Operation: types.UpdateOperationDelete})
+	go m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }

View File

@@ -13,7 +13,6 @@ import (
 	"github.com/netbirdio/netbird/management/server/permissions/modules"
 	"github.com/netbirdio/netbird/management/server/permissions/operations"
 	"github.com/netbirdio/netbird/management/server/store"
-	"github.com/netbirdio/netbird/management/server/types"
 	"github.com/netbirdio/netbird/shared/management/status"
 )
 
@@ -96,7 +95,7 @@ func (m *managerImpl) CreateRecord(ctx context.Context, accountID, userID, zoneI
 	meta := record.EventMeta(zone.ID, zone.Name)
 	m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordCreated, meta)
-	go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationCreate})
+	go m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return record, nil
 }
@@ -155,7 +154,7 @@ func (m *managerImpl) UpdateRecord(ctx context.Context, accountID, userID, zoneI
 	meta := record.EventMeta(zone.ID, zone.Name)
 	m.accountManager.StoreEvent(ctx, userID, record.ID, accountID, activity.DNSRecordUpdated, meta)
-	go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationUpdate})
+	go m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return record, nil
 }
@@ -202,7 +201,7 @@ func (m *managerImpl) DeleteRecord(ctx context.Context, accountID, userID, zoneI
 	meta := record.EventMeta(zone.ID, zone.Name)
 	m.accountManager.StoreEvent(ctx, userID, recordID, accountID, activity.DNSRecordDeleted, meta)
-	go m.accountManager.UpdateAccountPeers(ctx, accountID, types.UpdateReason{Resource: types.UpdateResourceZoneRecord, Operation: types.UpdateOperationDelete})
+	go m.accountManager.UpdateAccountPeers(ctx, accountID)
 
 	return nil
 }
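
Unlike the service manager, both DNS managers launch the fan-out with the go statement, so the request path returns before peers see the change, and any return value is discarded. The goroutine also inherits the request-scoped ctx, which matters if that context carries cancellation. The head-side call shape, isolated:

// Fire-and-forget: the go statement drops any result, and ctx outlives
// the handler that created it.
go m.accountManager.UpdateAccountPeers(ctx, accountID)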

View File

@@ -173,7 +173,7 @@ func (s *BaseServer) GRPCServer() *grpc.Server {
 	}
 	gRPCAPIHandler := grpc.NewServer(gRPCOpts...)
 
-	srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider(), s.SessionStore())
+	srv, err := nbgrpc.NewServer(s.Config, s.AccountManager(), s.SettingsManager(), s.JobManager(), s.SecretsManager(), s.Metrics(), s.AuthManager(), s.IntegratedValidator(), s.NetworkMapController(), s.OAuthConfigProvider())
 	if err != nil {
 		log.Fatalf("failed to create management server: %v", err)
 	}

View File

@@ -6,7 +6,6 @@ import (
 	log "github.com/sirupsen/logrus"
 
 	"github.com/netbirdio/management-integrations/integrations"
-	"github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy"
 	proxymanager "github.com/netbirdio/netbird/management/internals/modules/reverseproxy/proxy/manager"
@@ -67,12 +66,6 @@ func (s *BaseServer) SecretsManager() grpc.SecretsManager {
 	})
 }
 
-func (s *BaseServer) SessionStore() *auth.SessionStore {
-	return Create(s, func() *auth.SessionStore {
-		return auth.NewSessionStore(s.CacheStore())
-	})
-}
-
 func (s *BaseServer) AuthManager() auth.Manager {
 	audiences := s.Config.GetAuthAudiences()
 	audience := s.Config.HttpConfig.AuthAudience
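
For orientation, the deleted SessionStore provider used the same memoized-constructor pattern as SecretsManager above, with Create presumably caching one instance per BaseServer. A hypothetical provider in that style (ExampleStore and NewExampleStore are illustrative names, not part of the codebase):

func (s *BaseServer) ExampleStore() *ExampleStore { // hypothetical provider
	return Create(s, func() *ExampleStore {
		// Built once on first use, then reused for the server's lifetime
		// (assumed Create semantics, matching the providers above).
		return NewExampleStore(s.CacheStore()) // hypothetical constructor
	})
}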

Some files were not shown because too many files have changed in this diff.