mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-18 08:16:39 +00:00
Compare commits
92 Commits
feature/fl
...
feature/ex
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
55781d1e9d | ||
|
|
ab77508950 | ||
|
|
b9462f5c6b | ||
|
|
5ffaa5cdd6 | ||
|
|
a1858a9cb7 | ||
|
|
212b34f639 | ||
|
|
af8eaa23e2 | ||
|
|
f0eed50678 | ||
|
|
19d94c6158 | ||
|
|
628eb56073 | ||
|
|
a590c38d8b | ||
|
|
4e149c9222 | ||
|
|
59f5b34280 | ||
|
|
dff06d0898 | ||
|
|
80a8816b1d | ||
|
|
387e374e4b | ||
|
|
3e6baea405 | ||
|
|
fe9b844511 | ||
|
|
2e1aa497d2 | ||
|
|
529c0314f8 | ||
|
|
d86875aeac | ||
|
|
f80fe506d5 | ||
|
|
967c6f3cd3 | ||
|
|
e50e124e70 | ||
|
|
c545689448 | ||
|
|
8f389fef19 | ||
|
|
d3d6a327e0 | ||
|
|
b5489d4986 | ||
|
|
7a23c57cf8 | ||
|
|
11f891220e | ||
|
|
5585adce18 | ||
|
|
f884299823 | ||
|
|
15aa6bae1b | ||
|
|
11eb725ac8 | ||
|
|
30c02ab78c | ||
|
|
3acd86e346 | ||
|
|
5c20f13c48 | ||
|
|
e6587b071d | ||
|
|
85451ab4cd | ||
|
|
a7f3ba03eb | ||
|
|
4f0a3a77ad | ||
|
|
44655ca9b5 | ||
|
|
e601278117 | ||
|
|
8e7b016be2 | ||
|
|
9e01ea7aae | ||
|
|
cfc7ec8bb9 | ||
|
|
b3bbc0e5c6 | ||
|
|
d7c8e37ff4 | ||
|
|
05b66e73bc | ||
|
|
01ceedac89 | ||
|
|
403babd433 | ||
|
|
47133031e5 | ||
|
|
82da606886 | ||
|
|
bbe5ae2145 | ||
|
|
0b21498b39 | ||
|
|
0ca59535f1 | ||
|
|
59c77d0658 | ||
|
|
333e045099 | ||
|
|
c2c4d9d336 | ||
|
|
9a6a72e88e | ||
|
|
afe6d9fca4 | ||
|
|
ef82905526 | ||
|
|
d18747e846 | ||
|
|
f341d69314 | ||
|
|
327142837c | ||
|
|
f8c0321aee | ||
|
|
89115ff76a | ||
|
|
63c83aa8d2 | ||
|
|
37f025c966 | ||
|
|
4a54f0d670 | ||
|
|
98890a29e3 | ||
|
|
9d123ec059 | ||
|
|
5d171f181a | ||
|
|
22f878b3b7 | ||
|
|
44ef1a18dd | ||
|
|
2b98dc4e52 | ||
|
|
2a26cb4567 | ||
|
|
5ca1b64328 | ||
|
|
36752a8cbb | ||
|
|
f117fc7509 | ||
|
|
fc6b93ae59 | ||
|
|
564fa4ab04 | ||
|
|
a6db88fbd2 | ||
|
|
4b5294e596 | ||
|
|
a322dce42a | ||
|
|
d1ead2265b | ||
|
|
bbca74476e | ||
|
|
318cf59d66 | ||
|
|
e9b2a6e808 | ||
|
|
2dbdb5c1a7 | ||
|
|
2cdab6d7b7 | ||
|
|
e49c0e8862 |
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
||||
blank_issues_enabled: true
|
||||
contact_links:
|
||||
- name: Community Support
|
||||
url: https://forum.netbird.io/
|
||||
about: Community support forum
|
||||
- name: Cloud Support
|
||||
url: https://docs.netbird.io/help/report-bug-issues
|
||||
about: Contact us for support
|
||||
- name: Client/Connection Troubleshooting
|
||||
url: https://docs.netbird.io/help/troubleshooting-client
|
||||
about: See our client troubleshooting guide for help addressing common issues
|
||||
- name: Self-host Troubleshooting
|
||||
url: https://docs.netbird.io/selfhosted/troubleshooting
|
||||
about: See our self-host troubleshooting guide for help addressing common issues
|
||||
37
.github/workflows/golang-test-linux.yml
vendored
37
.github/workflows/golang-test-linux.yml
vendored
@@ -409,12 +409,19 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
@@ -497,15 +504,18 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
@@ -586,15 +596,18 @@ jobs:
|
||||
run: git --no-pager diff --exit-code
|
||||
|
||||
- name: Login to Docker hub
|
||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||
uses: docker/login-action@v1
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
uses: docker/login-action@v3
|
||||
with:
|
||||
username: ${{ secrets.DOCKER_USER }}
|
||||
password: ${{ secrets.DOCKER_TOKEN }}
|
||||
|
||||
- name: download mysql image
|
||||
if: matrix.store == 'mysql'
|
||||
run: docker pull mlsmaycon/warmed-mysql:8
|
||||
- name: docker login for root user
|
||||
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||
env:
|
||||
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||
|
||||
- name: Test
|
||||
run: |
|
||||
|
||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
||||
- name: codespell
|
||||
uses: codespell-project/actions-codespell@v2
|
||||
with:
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver
|
||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te
|
||||
skip: go.mod,go.sum,**/proxy/web/**
|
||||
golangci:
|
||||
strategy:
|
||||
|
||||
51
.github/workflows/pr-title-check.yml
vendored
Normal file
51
.github/workflows/pr-title-check.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
||||
name: PR Title Check
|
||||
|
||||
on:
|
||||
pull_request:
|
||||
types: [opened, edited, synchronize, reopened]
|
||||
|
||||
jobs:
|
||||
check-title:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Validate PR title prefix
|
||||
uses: actions/github-script@v7
|
||||
with:
|
||||
script: |
|
||||
const title = context.payload.pull_request.title;
|
||||
const allowedTags = [
|
||||
'management',
|
||||
'client',
|
||||
'signal',
|
||||
'proxy',
|
||||
'relay',
|
||||
'misc',
|
||||
'infrastructure',
|
||||
'self-hosted',
|
||||
'doc',
|
||||
];
|
||||
|
||||
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||
const match = title.match(pattern);
|
||||
|
||||
if (!match) {
|
||||
core.setFailed(
|
||||
`PR title must start with a tag in brackets.\n` +
|
||||
`Example: [client] fix something\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||
|
||||
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||
if (invalid.length > 0) {
|
||||
core.setFailed(
|
||||
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||
`Allowed tags: ${allowedTags.join(', ')}`
|
||||
);
|
||||
return;
|
||||
}
|
||||
|
||||
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||
86
.github/workflows/release.yml
vendored
86
.github/workflows/release.yml
vendored
@@ -10,7 +10,7 @@ on:
|
||||
|
||||
env:
|
||||
SIGN_PIPE_VER: "v0.1.1"
|
||||
GORELEASER_VER: "v2.3.2"
|
||||
GORELEASER_VER: "v2.14.3"
|
||||
PRODUCT_NAME: "NetBird"
|
||||
COPYRIGHT: "NetBird GmbH"
|
||||
|
||||
@@ -169,6 +169,13 @@ jobs:
|
||||
- name: Install OS build dependencies
|
||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||
|
||||
- name: Decode GPG signing key
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||
|
||||
- name: Install goversioninfo
|
||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||
- name: Generate windows syso amd64
|
||||
@@ -186,18 +193,54 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
- name: Tag and push PR images (amd64 only)
|
||||
if: github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
PR_TAG="pr-${{ github.event.pull_request.number }}"
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||
rpm --import /tmp/rpm-pub.key
|
||||
echo "=== Verifying RPM signatures ==="
|
||||
for rpm_file in /dist/*amd64*.rpm; do
|
||||
[ -f "$rpm_file" ] || continue
|
||||
echo "--- $(basename $rpm_file) ---"
|
||||
rpm -K "$rpm_file"
|
||||
done
|
||||
'
|
||||
- name: Clean up GPG key
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: Tag and push images (amd64 only)
|
||||
if: |
|
||||
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||
run: |
|
||||
resolve_tags() {
|
||||
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||
echo "pr-${{ github.event.pull_request.number }}"
|
||||
else
|
||||
echo "main sha-$(git rev-parse --short HEAD)"
|
||||
fi
|
||||
}
|
||||
|
||||
tag_and_push() {
|
||||
local src="$1" img_name tag dst
|
||||
img_name="${src%%:*}"
|
||||
for tag in $(resolve_tags); do
|
||||
dst="${img_name}:${tag}"
|
||||
echo "Tagging ${src} -> ${dst}"
|
||||
docker tag "$src" "$dst"
|
||||
docker push "$dst"
|
||||
done
|
||||
}
|
||||
|
||||
export -f tag_and_push resolve_tags
|
||||
|
||||
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||
grep '^ghcr.io/' | while read -r SRC; do
|
||||
IMG_NAME="${SRC%%:*}"
|
||||
DST="${IMG_NAME}:${PR_TAG}"
|
||||
echo "Tagging ${SRC} -> ${DST}"
|
||||
docker tag "$SRC" "$DST"
|
||||
docker push "$DST"
|
||||
tag_and_push "$SRC"
|
||||
done
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
@@ -265,6 +308,13 @@ jobs:
|
||||
- name: Install dependencies
|
||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||
|
||||
- name: Decode GPG signing key
|
||||
env:
|
||||
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||
run: |
|
||||
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||
|
||||
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
||||
run: |
|
||||
cd /tmp
|
||||
@@ -289,6 +339,24 @@ jobs:
|
||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||
- name: Verify RPM signatures
|
||||
run: |
|
||||
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||
rpm --import /tmp/rpm-pub.key
|
||||
echo "=== Verifying RPM signatures ==="
|
||||
for rpm_file in /dist/*.rpm; do
|
||||
[ -f "$rpm_file" ] || continue
|
||||
echo "--- $(basename $rpm_file) ---"
|
||||
rpm -K "$rpm_file"
|
||||
done
|
||||
'
|
||||
- name: Clean up GPG key
|
||||
if: always()
|
||||
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||
- name: upload non tags for debug purposes
|
||||
uses: actions/upload-artifact@v4
|
||||
with:
|
||||
|
||||
@@ -171,13 +171,12 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
id: netbird-deb
|
||||
id: netbird_deb
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
- netbird
|
||||
formats:
|
||||
- deb
|
||||
|
||||
scripts:
|
||||
postinstall: "release_files/post_install.sh"
|
||||
preremove: "release_files/pre_remove.sh"
|
||||
@@ -185,16 +184,18 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client.
|
||||
homepage: https://netbird.io/
|
||||
id: netbird-rpm
|
||||
id: netbird_rpm
|
||||
bindir: /usr/bin
|
||||
builds:
|
||||
- netbird
|
||||
formats:
|
||||
- rpm
|
||||
|
||||
scripts:
|
||||
postinstall: "release_files/post_install.sh"
|
||||
preremove: "release_files/pre_remove.sh"
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
dockers:
|
||||
- image_templates:
|
||||
- netbirdio/netbird:{{ .Version }}-amd64
|
||||
@@ -876,7 +877,7 @@ brews:
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-deb
|
||||
- netbird_deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
@@ -884,7 +885,7 @@ uploads:
|
||||
|
||||
- name: yum
|
||||
ids:
|
||||
- netbird-rpm
|
||||
- netbird_rpm
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -61,7 +61,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
id: netbird-ui-deb
|
||||
id: netbird_ui_deb
|
||||
package_name: netbird-ui
|
||||
builds:
|
||||
- netbird-ui
|
||||
@@ -80,7 +80,7 @@ nfpms:
|
||||
- maintainer: Netbird <dev@netbird.io>
|
||||
description: Netbird client UI.
|
||||
homepage: https://netbird.io/
|
||||
id: netbird-ui-rpm
|
||||
id: netbird_ui_rpm
|
||||
package_name: netbird-ui
|
||||
builds:
|
||||
- netbird-ui
|
||||
@@ -95,11 +95,14 @@ nfpms:
|
||||
dst: /usr/share/pixmaps/netbird.png
|
||||
dependencies:
|
||||
- netbird
|
||||
rpm:
|
||||
signature:
|
||||
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||
|
||||
uploads:
|
||||
- name: debian
|
||||
ids:
|
||||
- netbird-ui-deb
|
||||
- netbird_ui_deb
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||
username: dev@wiretrustee.com
|
||||
@@ -107,7 +110,7 @@ uploads:
|
||||
|
||||
- name: yum
|
||||
ids:
|
||||
- netbird-ui-rpm
|
||||
- netbird_ui_rpm
|
||||
mode: archive
|
||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||
username: dev@wiretrustee.com
|
||||
|
||||
@@ -126,6 +126,7 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
|
||||
### Community projects
|
||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||
|
||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||
|
||||
@@ -4,7 +4,7 @@
|
||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||
|
||||
FROM alpine:3.23.2
|
||||
FROM alpine:3.23.3
|
||||
# iproute2: busybox doesn't display ip rules properly
|
||||
RUN apk add --no-cache \
|
||||
bash \
|
||||
|
||||
@@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
@@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
||||
|
||||
// todo do not throw error in case of cancelled context
|
||||
ctx = internal.CtxInitState(ctx)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||
}
|
||||
|
||||
|
||||
280
client/cmd/expose.go
Normal file
280
client/cmd/expose.go
Normal file
@@ -0,0 +1,280 @@
|
||||
package cmd
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
"os"
|
||||
"os/signal"
|
||||
"regexp"
|
||||
"strconv"
|
||||
"strings"
|
||||
"syscall"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||
|
||||
var (
|
||||
exposePin string
|
||||
exposePassword string
|
||||
exposeUserGroups []string
|
||||
exposeDomain string
|
||||
exposeNamePrefix string
|
||||
exposeProtocol string
|
||||
exposeExternalPort uint16
|
||||
)
|
||||
|
||||
var exposeCmd = &cobra.Command{
|
||||
Use: "expose <port>",
|
||||
Short: "Expose a local port via the NetBird reverse proxy",
|
||||
Args: cobra.ExactArgs(1),
|
||||
Example: ` netbird expose --with-password safe-pass 8080
|
||||
netbird expose --protocol tcp 5432
|
||||
netbird expose --protocol tcp --with-external-port 5433 5432
|
||||
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
|
||||
RunE: exposeFn,
|
||||
}
|
||||
|
||||
func init() {
|
||||
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
|
||||
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
|
||||
}
|
||||
|
||||
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
|
||||
func isClusterProtocol(protocol string) bool {
|
||||
switch strings.ToLower(protocol) {
|
||||
case "tcp", "udp", "tls":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
|
||||
// where domain display doesn't apply. TLS uses SNI so it has a domain.
|
||||
func isPortBasedProtocol(protocol string) bool {
|
||||
switch strings.ToLower(protocol) {
|
||||
case "tcp", "udp":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
// extractPort returns the port portion of a URL like "tcp://host:12345", or
|
||||
// falls back to the given default formatted as a string.
|
||||
func extractPort(serviceURL string, fallback uint16) string {
|
||||
u := serviceURL
|
||||
if idx := strings.Index(u, "://"); idx != -1 {
|
||||
u = u[idx+3:]
|
||||
}
|
||||
if i := strings.LastIndex(u, ":"); i != -1 {
|
||||
if p := u[i+1:]; p != "" {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return strconv.FormatUint(uint64(fallback), 10)
|
||||
}
|
||||
|
||||
// resolveExternalPort returns the effective external port, defaulting to the target port.
|
||||
func resolveExternalPort(targetPort uint64) uint16 {
|
||||
if exposeExternalPort != 0 {
|
||||
return exposeExternalPort
|
||||
}
|
||||
return uint16(targetPort)
|
||||
}
|
||||
|
||||
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||
}
|
||||
if port == 0 || port > 65535 {
|
||||
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||
}
|
||||
|
||||
if !isProtocolValid(exposeProtocol) {
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||
}
|
||||
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
|
||||
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
|
||||
}
|
||||
} else if cmd.Flags().Changed("with-external-port") {
|
||||
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
|
||||
}
|
||||
|
||||
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||
return 0, fmt.Errorf("password cannot be empty")
|
||||
}
|
||||
|
||||
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||
return 0, fmt.Errorf("user groups cannot be empty")
|
||||
}
|
||||
|
||||
return port, nil
|
||||
}
|
||||
|
||||
func isProtocolValid(exposeProtocol string) bool {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http", "https", "tcp", "udp", "tls":
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
|
||||
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||
log.Errorf("failed initializing log %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = false
|
||||
|
||||
port, err := validateExposeFlags(cmd, args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
cmd.Root().SilenceUsage = true
|
||||
|
||||
ctx, cancel := context.WithCancel(cmd.Context())
|
||||
defer cancel()
|
||||
|
||||
sigCh := make(chan os.Signal, 1)
|
||||
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||
go func() {
|
||||
<-sigCh
|
||||
cancel()
|
||||
}()
|
||||
|
||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("connect to daemon: %w", err)
|
||||
}
|
||||
defer func() {
|
||||
if err := conn.Close(); err != nil {
|
||||
log.Debugf("failed to close daemon connection: %v", err)
|
||||
}
|
||||
}()
|
||||
|
||||
client := proto.NewDaemonServiceClient(conn)
|
||||
|
||||
protocol, err := toExposeProtocol(exposeProtocol)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
req := &proto.ExposeServiceRequest{
|
||||
Port: uint32(port),
|
||||
Protocol: protocol,
|
||||
Pin: exposePin,
|
||||
Password: exposePassword,
|
||||
UserGroups: exposeUserGroups,
|
||||
Domain: exposeDomain,
|
||||
NamePrefix: exposeNamePrefix,
|
||||
}
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
req.ListenPort = uint32(resolveExternalPort(port))
|
||||
}
|
||||
|
||||
stream, err := client.ExposeService(ctx, req)
|
||||
if err != nil {
|
||||
return fmt.Errorf("expose service: %w", err)
|
||||
}
|
||||
|
||||
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return waitForExposeEvents(cmd, ctx, stream)
|
||||
}
|
||||
|
||||
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||
switch strings.ToLower(exposeProtocol) {
|
||||
case "http":
|
||||
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||
case "https":
|
||||
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||
case "tcp":
|
||||
return proto.ExposeProtocol_EXPOSE_TCP, nil
|
||||
case "udp":
|
||||
return proto.ExposeProtocol_EXPOSE_UDP, nil
|
||||
case "tls":
|
||||
return proto.ExposeProtocol_EXPOSE_TLS, nil
|
||||
default:
|
||||
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||
}
|
||||
}
|
||||
|
||||
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||
event, err := stream.Recv()
|
||||
if err != nil {
|
||||
return fmt.Errorf("receive expose event: %w", err)
|
||||
}
|
||||
|
||||
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||
if !ok {
|
||||
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||
}
|
||||
printExposeReady(cmd, ready.Ready, port)
|
||||
return nil
|
||||
}
|
||||
|
||||
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
|
||||
cmd.Println("Service exposed successfully!")
|
||||
cmd.Printf(" Name: %s\n", r.ServiceName)
|
||||
if r.ServiceUrl != "" {
|
||||
cmd.Printf(" URL: %s\n", r.ServiceUrl)
|
||||
}
|
||||
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
|
||||
cmd.Printf(" Domain: %s\n", r.Domain)
|
||||
}
|
||||
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||
cmd.Printf(" Internal: %d\n", port)
|
||||
if isClusterProtocol(exposeProtocol) {
|
||||
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
|
||||
}
|
||||
if r.PortAutoAssigned && exposeExternalPort != 0 {
|
||||
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
|
||||
}
|
||||
cmd.Println()
|
||||
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||
}
|
||||
|
||||
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||
for {
|
||||
_, err := stream.Recv()
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
cmd.Println("\nService stopped.")
|
||||
//nolint:nilerr
|
||||
return nil
|
||||
}
|
||||
if errors.Is(err, io.EOF) {
|
||||
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||
}
|
||||
return fmt.Errorf("stream error: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -22,6 +22,7 @@ import (
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
|
||||
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
)
|
||||
|
||||
@@ -80,6 +81,15 @@ var (
|
||||
Short: "",
|
||||
Long: "",
|
||||
SilenceUsage: true,
|
||||
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||
SetFlagsFromEnvVars(cmd.Root())
|
||||
|
||||
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||
if !isServiceCmd(cmd) {
|
||||
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||
}
|
||||
return nil
|
||||
},
|
||||
}
|
||||
)
|
||||
|
||||
@@ -144,6 +154,7 @@ func init() {
|
||||
rootCmd.AddCommand(forwardingRulesCmd)
|
||||
rootCmd.AddCommand(debugCmd)
|
||||
rootCmd.AddCommand(profileCmd)
|
||||
rootCmd.AddCommand(exposeCmd)
|
||||
|
||||
networksCMD.AddCommand(routesListCmd)
|
||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||
@@ -385,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
||||
}
|
||||
|
||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||
@@ -398,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||
|
||||
return conn, nil
|
||||
}
|
||||
|
||||
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||
func isServiceCmd(cmd *cobra.Command) bool {
|
||||
for c := cmd; c != nil; c = c.Parent() {
|
||||
if c.Name() == "service" {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
|
||||
|
||||
// Common setup for service control commands
|
||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||
SetFlagsFromEnvVars(rootCmd)
|
||||
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||
SetFlagsFromEnvVars(serviceCmd)
|
||||
|
||||
cmd.SetOut(cmd.OutOrStdout())
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -6,7 +6,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
const (
|
||||
|
||||
@@ -7,7 +7,7 @@ import (
|
||||
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
var (
|
||||
|
||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
||||
r := peer.NewRecorder(config.ManagementURL.String())
|
||||
r.GetFullStatus()
|
||||
|
||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||
|
||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
)
|
||||
|
||||
|
||||
@@ -14,6 +14,7 @@ import (
|
||||
"github.com/sirupsen/logrus"
|
||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
"github.com/netbirdio/netbird/client/internal/auth"
|
||||
@@ -21,6 +22,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
)
|
||||
|
||||
@@ -81,6 +83,14 @@ type Options struct {
|
||||
BlockInbound bool
|
||||
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||
WireguardPort *int
|
||||
// MTU is the MTU for the WireGuard interface.
|
||||
// Valid values are in the range 576..8192 bytes.
|
||||
// If non-nil, this value overrides any value stored in the config file.
|
||||
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
||||
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
|
||||
MTU *uint16
|
||||
// DNSLabels defines additional DNS labels configured in the peer.
|
||||
DNSLabels []string
|
||||
}
|
||||
|
||||
// validateCredentials checks that exactly one credential type is provided
|
||||
@@ -112,6 +122,12 @@ func New(opts Options) (*Client, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if opts.MTU != nil {
|
||||
if err := iface.ValidateMTU(*opts.MTU); err != nil {
|
||||
return nil, fmt.Errorf("invalid MTU: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
if opts.LogOutput != nil {
|
||||
logrus.SetOutput(opts.LogOutput)
|
||||
}
|
||||
@@ -140,9 +156,14 @@ func New(opts Options) (*Client, error) {
|
||||
}
|
||||
}
|
||||
|
||||
var err error
|
||||
var parsedLabels domain.List
|
||||
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
|
||||
return nil, fmt.Errorf("invalid dns labels: %w", err)
|
||||
}
|
||||
|
||||
t := true
|
||||
var config *profilemanager.Config
|
||||
var err error
|
||||
input := profilemanager.ConfigInput{
|
||||
ConfigPath: opts.ConfigPath,
|
||||
ManagementURL: opts.ManagementURL,
|
||||
@@ -151,6 +172,8 @@ func New(opts Options) (*Client, error) {
|
||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||
BlockInbound: &opts.BlockInbound,
|
||||
WireguardPort: opts.WireguardPort,
|
||||
MTU: opts.MTU,
|
||||
DNSLabels: parsedLabels,
|
||||
}
|
||||
if opts.ConfigPath != "" {
|
||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||
@@ -202,7 +225,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||
return fmt.Errorf("login: %w", err)
|
||||
}
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder, false)
|
||||
client := internal.NewConnectClient(ctx, c.config, c.recorder)
|
||||
client.SetSyncResponsePersistence(true)
|
||||
|
||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||
|
||||
@@ -23,9 +23,10 @@ type Manager struct {
|
||||
|
||||
wgIface iFaceMapper
|
||||
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
ipv4Client *iptables.IPTables
|
||||
aclMgr *aclManager
|
||||
router *router
|
||||
rawSupported bool
|
||||
}
|
||||
|
||||
// iFaceMapper defines subset methods of interface required for manager
|
||||
@@ -84,7 +85,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChain(); err != nil {
|
||||
return fmt.Errorf("init notrack chain: %w", err)
|
||||
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
// persist early to ensure cleanup of chains
|
||||
@@ -318,6 +319,10 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||
m.mutex.Lock()
|
||||
defer m.mutex.Unlock()
|
||||
|
||||
if !m.rawSupported {
|
||||
return fmt.Errorf("raw table not available")
|
||||
}
|
||||
|
||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||
|
||||
@@ -375,12 +380,16 @@ func (m *Manager) initNoTrackChain() error {
|
||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||
}
|
||||
|
||||
m.rawSupported = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (m *Manager) cleanupNoTrackChain() error {
|
||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||
if err != nil {
|
||||
if !m.rawSupported {
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("check chain exists: %w", err)
|
||||
}
|
||||
if !exists {
|
||||
@@ -401,6 +410,7 @@ func (m *Manager) cleanupNoTrackChain() error {
|
||||
return fmt.Errorf("clear and delete chain: %w", err)
|
||||
}
|
||||
|
||||
m.rawSupported = false
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -95,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
||||
}
|
||||
|
||||
if err := m.initNoTrackChains(workTable); err != nil {
|
||||
return fmt.Errorf("init notrack chains: %w", err)
|
||||
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||
}
|
||||
|
||||
stateManager.RegisterState(&ShutdownState{})
|
||||
|
||||
@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
||||
|
||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||
// for js, the outer websocket layer takes care of tls
|
||||
if tlsEnabled && runtime.GOOS != "js" {
|
||||
@@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||
defer cancel()
|
||||
|
||||
conn, err := grpc.DialContext(
|
||||
connCtx,
|
||||
addr,
|
||||
opts := []grpc.DialOption{
|
||||
transportOption,
|
||||
WithCustomDialer(tlsEnabled, component),
|
||||
grpc.WithBlock(),
|
||||
@@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
||||
Time: 30 * time.Second,
|
||||
Timeout: 10 * time.Second,
|
||||
}),
|
||||
)
|
||||
}
|
||||
opts = append(opts, extraOpts...)
|
||||
|
||||
conn, err := grpc.DialContext(connCtx, addr, opts...)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("dial context: %w", err)
|
||||
}
|
||||
|
||||
@@ -5,20 +5,18 @@ package configurer
|
||||
import (
|
||||
"net"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.zx2c4.com/wireguard/ipc"
|
||||
)
|
||||
|
||||
func openUAPI(deviceName string) (net.Listener, error) {
|
||||
uapiSock, err := ipc.UAPIOpen(deviceName)
|
||||
if err != nil {
|
||||
log.Errorf("failed to open uapi socket: %v", err)
|
||||
return nil, err
|
||||
}
|
||||
|
||||
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
||||
if err != nil {
|
||||
log.Errorf("failed to listen on uapi socket: %v", err)
|
||||
_ = uapiSock.Close()
|
||||
return nil, err
|
||||
}
|
||||
|
||||
|
||||
@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
||||
return wgCfg
|
||||
}
|
||||
|
||||
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||
return &WGUSPConfigurer{
|
||||
device: device,
|
||||
deviceName: deviceName,
|
||||
activityRecorder: activityRecorder,
|
||||
}
|
||||
}
|
||||
|
||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||
log.Debugf("adding Wireguard private key")
|
||||
key, err := wgtypes.ParseKey(privateKey)
|
||||
|
||||
@@ -79,7 +79,7 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||
)
|
||||
|
||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
||||
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
|
||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||
if err != nil {
|
||||
if cErr := tunIface.Close(); cErr != nil {
|
||||
|
||||
@@ -27,8 +27,8 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
nbnet "github.com/netbirdio/netbird/client/net"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/client/ssh"
|
||||
@@ -44,13 +44,13 @@ import (
|
||||
)
|
||||
|
||||
type ConnectClient struct {
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
doInitialAutoUpdate bool
|
||||
ctx context.Context
|
||||
config *profilemanager.Config
|
||||
statusRecorder *peer.Status
|
||||
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
engine *Engine
|
||||
engineMutex sync.Mutex
|
||||
updateManager *updater.Manager
|
||||
|
||||
persistSyncResponse bool
|
||||
}
|
||||
@@ -59,17 +59,19 @@ func NewConnectClient(
|
||||
ctx context.Context,
|
||||
config *profilemanager.Config,
|
||||
statusRecorder *peer.Status,
|
||||
doInitalAutoUpdate bool,
|
||||
) *ConnectClient {
|
||||
return &ConnectClient{
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
doInitialAutoUpdate: doInitalAutoUpdate,
|
||||
engineMutex: sync.Mutex{},
|
||||
ctx: ctx,
|
||||
config: config,
|
||||
statusRecorder: statusRecorder,
|
||||
engineMutex: sync.Mutex{},
|
||||
}
|
||||
}
|
||||
|
||||
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||
c.updateManager = um
|
||||
}
|
||||
|
||||
// Run with main logic.
|
||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||
return c.run(MobileDependency{}, runningChan, logPath)
|
||||
@@ -187,14 +189,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
stateManager := statemanager.New(path)
|
||||
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||
|
||||
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
||||
if err == nil {
|
||||
updateManager.CheckUpdateSuccess(c.ctx)
|
||||
if c.updateManager != nil {
|
||||
c.updateManager.CheckUpdateSuccess(c.ctx)
|
||||
}
|
||||
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
}
|
||||
inst := installer.New()
|
||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||
}
|
||||
|
||||
defer c.statusRecorder.ClientStop()
|
||||
@@ -308,7 +309,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
checks := loginResp.GetChecks()
|
||||
|
||||
c.engineMutex.Lock()
|
||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
||||
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmClient,
|
||||
RelayManager: relayManager,
|
||||
StatusRecorder: c.statusRecorder,
|
||||
Checks: checks,
|
||||
StateManager: stateManager,
|
||||
UpdateManager: c.updateManager,
|
||||
}, mobileDependency)
|
||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||
c.engine = engine
|
||||
c.engineMutex.Unlock()
|
||||
@@ -318,21 +327,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
||||
return wrapErr(err)
|
||||
}
|
||||
|
||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
||||
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
||||
if c.doInitialAutoUpdate {
|
||||
log.Infof("start engine by ui, run auto-update check")
|
||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
||||
c.doInitialAutoUpdate = false
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||
state.Set(StatusConnected)
|
||||
|
||||
if runningChan != nil {
|
||||
close(runningChan)
|
||||
runningChan = nil
|
||||
select {
|
||||
case <-runningChan:
|
||||
default:
|
||||
close(runningChan)
|
||||
}
|
||||
}
|
||||
|
||||
<-engineCtx.Done()
|
||||
|
||||
60
client/internal/daemonaddr/resolve.go
Normal file
60
client/internal/daemonaddr/resolve.go
Normal file
@@ -0,0 +1,60 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
var scanDir = "/var/run/netbird"
|
||||
|
||||
// setScanDir overrides the scan directory (used by tests).
|
||||
func setScanDir(dir string) {
|
||||
scanDir = dir
|
||||
}
|
||||
|
||||
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||
// mismatch between the netbird@.service template (which places the socket under
|
||||
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
if !strings.HasPrefix(addr, "unix://") {
|
||||
return addr
|
||||
}
|
||||
|
||||
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||
if _, err := os.Stat(sockPath); err == nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
entries, err := os.ReadDir(scanDir)
|
||||
if err != nil {
|
||||
return addr
|
||||
}
|
||||
|
||||
var found []string
|
||||
for _, e := range entries {
|
||||
if e.IsDir() {
|
||||
continue
|
||||
}
|
||||
if strings.HasSuffix(e.Name(), ".sock") {
|
||||
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||
}
|
||||
}
|
||||
|
||||
switch len(found) {
|
||||
case 1:
|
||||
resolved := "unix://" + found[0]
|
||||
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||
return resolved
|
||||
case 0:
|
||||
return addr
|
||||
default:
|
||||
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||
return addr
|
||||
}
|
||||
}
|
||||
8
client/internal/daemonaddr/resolve_stub.go
Normal file
8
client/internal/daemonaddr/resolve_stub.go
Normal file
@@ -0,0 +1,8 @@
|
||||
//go:build windows || ios || android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||
func ResolveUnixDaemonAddr(addr string) string {
|
||||
return addr
|
||||
}
|
||||
121
client/internal/daemonaddr/resolve_test.go
Normal file
121
client/internal/daemonaddr/resolve_test.go
Normal file
@@ -0,0 +1,121 @@
|
||||
//go:build !windows && !ios && !android
|
||||
|
||||
package daemonaddr
|
||||
|
||||
import (
|
||||
"os"
|
||||
"path/filepath"
|
||||
"testing"
|
||||
)
|
||||
|
||||
// createSockFile creates a regular file with a .sock extension.
|
||||
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||
func createSockFile(t *testing.T, path string) {
|
||||
t.Helper()
|
||||
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
sock := filepath.Join(tmp, "netbird.sock")
|
||||
createSockFile(t, sock)
|
||||
|
||||
addr := "unix://" + sock
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
// Default socket does not exist
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
// Create a scan dir with one socket
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
instanceSock := filepath.Join(sd, "main.sock")
|
||||
createSockFile(t, instanceSock)
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
expected := "unix://" + instanceSock
|
||||
if got != expected {
|
||||
t.Errorf("expected %s, got %s", expected, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
sd := filepath.Join(tmp, "netbird")
|
||||
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||
t.Fatal(err)
|
||||
}
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(sd)
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||
addr := "tcp://127.0.0.1:41731"
|
||||
got := ResolveUnixDaemonAddr(addr)
|
||||
if got != addr {
|
||||
t.Errorf("expected %s, got %s", addr, got)
|
||||
}
|
||||
}
|
||||
|
||||
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||
tmp := t.TempDir()
|
||||
|
||||
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||
|
||||
origScanDir := scanDir
|
||||
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||
|
||||
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||
if got != defaultAddr {
|
||||
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||
}
|
||||
}
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"github.com/netbirdio/netbird/client/anonymize"
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
|
||||
@@ -14,6 +14,8 @@ import (
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/hashicorp/go-multierror"
|
||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"golang.org/x/exp/maps"
|
||||
|
||||
@@ -22,6 +24,7 @@ import (
|
||||
|
||||
const (
|
||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
|
||||
globalIPv4State = "State:/Network/Global/IPv4"
|
||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||
@@ -35,6 +38,14 @@ const (
|
||||
searchSuffix = "Search"
|
||||
matchSuffix = "Match"
|
||||
localSuffix = "Local"
|
||||
|
||||
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
|
||||
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
|
||||
maxDomainsPerResolverEntry = 50
|
||||
|
||||
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
|
||||
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
|
||||
maxDomainBytesPerResolverEntry = 1500
|
||||
)
|
||||
|
||||
type systemConfigurator struct {
|
||||
@@ -84,28 +95,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||
}
|
||||
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
var err error
|
||||
if len(matchDomains) != 0 {
|
||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing match domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(matchKey)
|
||||
if err := s.removeKeysContaining(matchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old match keys: %v", err)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
if len(matchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
|
||||
return fmt.Errorf("add match domains: %w", err)
|
||||
}
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
if len(searchDomains) != 0 {
|
||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
||||
} else {
|
||||
log.Infof("removing search domains from the system")
|
||||
err = s.removeKeyFromSystemConfig(searchKey)
|
||||
if err := s.removeKeysContaining(searchSuffix); err != nil {
|
||||
log.Warnf("failed to remove old search keys: %v", err)
|
||||
}
|
||||
if err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
if len(searchDomains) != 0 {
|
||||
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
}
|
||||
}
|
||||
s.updateState(stateManager)
|
||||
|
||||
@@ -149,8 +155,7 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
||||
|
||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
if len(s.createdKeys) == 0 {
|
||||
// return defaults for startup calls
|
||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
||||
return s.discoverExistingKeys()
|
||||
}
|
||||
|
||||
keys := make([]string, 0, len(s.createdKeys))
|
||||
@@ -160,6 +165,47 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||
return keys
|
||||
}
|
||||
|
||||
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
|
||||
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
|
||||
func (s *systemConfigurator) discoverExistingKeys() []string {
|
||||
dnsKeys, err := getSystemDNSKeys()
|
||||
if err != nil {
|
||||
log.Errorf("failed to get system DNS keys: %v", err)
|
||||
return nil
|
||||
}
|
||||
|
||||
var keys []string
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
|
||||
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
|
||||
if strings.Contains(dnsKeys, key) {
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
for _, suffix := range []string{searchSuffix, matchSuffix} {
|
||||
for i := 0; ; i++ {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
if !strings.Contains(dnsKeys, key) {
|
||||
break
|
||||
}
|
||||
keys = append(keys, key)
|
||||
}
|
||||
}
|
||||
|
||||
return keys
|
||||
}
|
||||
|
||||
// getSystemDNSKeys gets all DNS keys
|
||||
func getSystemDNSKeys() (string, error) {
|
||||
command := "list .*DNS\nquit\n"
|
||||
out, err := runSystemConfigCommand(command)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(out), nil
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||
line := buildRemoveKeyOperation(key)
|
||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||
@@ -184,12 +230,11 @@ func (s *systemConfigurator) addLocalDNS() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := s.addSearchDomains(
|
||||
localKey,
|
||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
||||
); err != nil {
|
||||
return fmt.Errorf("add search domains: %w", err)
|
||||
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
|
||||
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
|
||||
return fmt.Errorf("add local dns state: %w", err)
|
||||
}
|
||||
s.createdKeys[localKey] = struct{}{}
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -280,28 +325,77 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
||||
return slices.Clone(s.origNameservers)
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, ip, port, true)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
|
||||
func splitDomainsIntoBatches(domains []string) [][]string {
|
||||
if len(domains) == 0 {
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
var batches [][]string
|
||||
var current []string
|
||||
currentBytes := 0
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
for _, d := range domains {
|
||||
domainLen := len(d)
|
||||
newBytes := currentBytes + domainLen
|
||||
if currentBytes > 0 {
|
||||
newBytes++ // space separator
|
||||
}
|
||||
|
||||
return nil
|
||||
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
|
||||
batches = append(batches, current)
|
||||
current = nil
|
||||
currentBytes = 0
|
||||
}
|
||||
|
||||
current = append(current, d)
|
||||
if currentBytes > 0 {
|
||||
currentBytes += 1 + domainLen
|
||||
} else {
|
||||
currentBytes = domainLen
|
||||
}
|
||||
}
|
||||
|
||||
if len(current) > 0 {
|
||||
batches = append(batches, current)
|
||||
}
|
||||
|
||||
return batches
|
||||
}
|
||||
|
||||
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
|
||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
||||
if err != nil {
|
||||
return fmt.Errorf("add dns state: %w", err)
|
||||
// removeKeysContaining removes all created keys that contain the given substring.
|
||||
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
|
||||
var toRemove []string
|
||||
for key := range s.createdKeys {
|
||||
if strings.Contains(key, suffix) {
|
||||
toRemove = append(toRemove, key)
|
||||
}
|
||||
}
|
||||
var multiErr *multierror.Error
|
||||
for _, key := range toRemove {
|
||||
if err := s.removeKeyFromSystemConfig(key); err != nil {
|
||||
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
|
||||
}
|
||||
}
|
||||
return nberrors.FormatErrorOrNil(multiErr)
|
||||
}
|
||||
|
||||
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
|
||||
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
|
||||
for i, batch := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||
domainsStr := strings.Join(batch, " ")
|
||||
|
||||
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
|
||||
return fmt.Errorf("add dns state for batch %d: %w", i, err)
|
||||
}
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
}
|
||||
|
||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
||||
|
||||
s.createdKeys[key] = struct{}{}
|
||||
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
|
||||
|
||||
return nil
|
||||
}
|
||||
@@ -364,7 +458,6 @@ func (s *systemConfigurator) flushDNSCache() error {
|
||||
if out, err := cmd.CombinedOutput(); err != nil {
|
||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||
}
|
||||
|
||||
log.Info("flushed DNS cache")
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -3,7 +3,10 @@
|
||||
package dns
|
||||
|
||||
import (
|
||||
"bufio"
|
||||
"bytes"
|
||||
"context"
|
||||
"fmt"
|
||||
"net/netip"
|
||||
"os/exec"
|
||||
"path/filepath"
|
||||
@@ -49,17 +52,22 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
|
||||
require.NoError(t, sm.PersistState(context.Background()))
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
// Collect all created keys for cleanup verification
|
||||
createdKeys := make([]string, 0, len(configurator.createdKeys))
|
||||
for key := range configurator.createdKeys {
|
||||
createdKeys = append(createdKeys, key)
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
for _, key := range createdKeys {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
_ = removeTestDNSKey(localKey)
|
||||
}()
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
for _, key := range createdKeys {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
if exists {
|
||||
@@ -83,13 +91,223 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
||||
err = shutdownState.Cleanup()
|
||||
require.NoError(t, err)
|
||||
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
for _, key := range createdKeys {
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||
}
|
||||
}
|
||||
|
||||
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
|
||||
func generateShortDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
label := ""
|
||||
n := i
|
||||
for {
|
||||
label = string(rune('a'+n%26)) + label
|
||||
n = n/26 - 1
|
||||
if n < 0 {
|
||||
break
|
||||
}
|
||||
}
|
||||
domains = append(domains, label+".com")
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
|
||||
func generateLongDomains(count int) []string {
|
||||
domains := make([]string, 0, count)
|
||||
for i := range count {
|
||||
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
|
||||
}
|
||||
return domains
|
||||
}
|
||||
|
||||
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
|
||||
func readDomainsFromKey(t *testing.T, key string) []string {
|
||||
t.Helper()
|
||||
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
|
||||
out, err := cmd.Output()
|
||||
require.NoError(t, err, "scutil show should succeed")
|
||||
|
||||
var domains []string
|
||||
inArray := false
|
||||
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||
for scanner.Scan() {
|
||||
line := strings.TrimSpace(scanner.Text())
|
||||
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
|
||||
inArray = true
|
||||
continue
|
||||
}
|
||||
if inArray {
|
||||
if line == "}" {
|
||||
break
|
||||
}
|
||||
// lines look like: "0 : a.com"
|
||||
parts := strings.SplitN(line, " : ", 2)
|
||||
if len(parts) == 2 {
|
||||
domains = append(domains, parts[1])
|
||||
}
|
||||
}
|
||||
}
|
||||
require.NoError(t, scanner.Err())
|
||||
return domains
|
||||
}
|
||||
|
||||
func TestSplitDomainsIntoBatches(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
domains []string
|
||||
expectedCount int
|
||||
checkAllPresent bool
|
||||
}{
|
||||
{
|
||||
name: "empty",
|
||||
domains: nil,
|
||||
expectedCount: 0,
|
||||
},
|
||||
{
|
||||
name: "under_limit",
|
||||
domains: generateShortDomains(10),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "at_element_limit",
|
||||
domains: generateShortDomains(50),
|
||||
expectedCount: 1,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "over_element_limit",
|
||||
domains: generateShortDomains(51),
|
||||
expectedCount: 2,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "triple_element_limit",
|
||||
domains: generateShortDomains(150),
|
||||
expectedCount: 3,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "long_domains_hit_byte_limit",
|
||||
domains: generateLongDomains(50),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_short_domains",
|
||||
domains: generateShortDomains(500),
|
||||
expectedCount: 10,
|
||||
checkAllPresent: true,
|
||||
},
|
||||
{
|
||||
name: "500_long_domains",
|
||||
domains: generateLongDomains(500),
|
||||
checkAllPresent: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range tests {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
batches := splitDomainsIntoBatches(tc.domains)
|
||||
|
||||
if tc.expectedCount > 0 {
|
||||
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
|
||||
}
|
||||
|
||||
// Verify each batch respects limits
|
||||
for i, batch := range batches {
|
||||
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
|
||||
"batch %d exceeds element limit", i)
|
||||
|
||||
totalBytes := 0
|
||||
for j, d := range batch {
|
||||
if j > 0 {
|
||||
totalBytes++
|
||||
}
|
||||
totalBytes += len(d)
|
||||
}
|
||||
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
|
||||
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
|
||||
}
|
||||
|
||||
if tc.checkAllPresent {
|
||||
var all []string
|
||||
for _, batch := range batches {
|
||||
all = append(all, batch...)
|
||||
}
|
||||
assert.Equal(t, tc.domains, all, "all domains should be present in order")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
|
||||
// and verifies all domains are readable across multiple scutil keys.
|
||||
func TestMatchDomainBatching(t *testing.T) {
|
||||
if testing.Short() {
|
||||
t.Skip("skipping scutil integration test in short mode")
|
||||
}
|
||||
|
||||
testCases := []struct {
|
||||
name string
|
||||
count int
|
||||
generator func(int) []string
|
||||
}{
|
||||
{"short_10", 10, generateShortDomains},
|
||||
{"short_50", 50, generateShortDomains},
|
||||
{"short_100", 100, generateShortDomains},
|
||||
{"short_200", 200, generateShortDomains},
|
||||
{"short_500", 500, generateShortDomains},
|
||||
{"long_10", 10, generateLongDomains},
|
||||
{"long_50", 50, generateLongDomains},
|
||||
{"long_100", 100, generateLongDomains},
|
||||
{"long_200", 200, generateLongDomains},
|
||||
{"long_500", 500, generateLongDomains},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
configurator := &systemConfigurator{
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
defer func() {
|
||||
for key := range configurator.createdKeys {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
}()
|
||||
|
||||
domains := tc.generator(tc.count)
|
||||
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
|
||||
require.NoError(t, err)
|
||||
|
||||
batches := splitDomainsIntoBatches(domains)
|
||||
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
|
||||
|
||||
// Read back all domains from all batched keys
|
||||
var got []string
|
||||
for i := range batches {
|
||||
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
|
||||
exists, err := checkDNSKeyExists(key)
|
||||
require.NoError(t, err)
|
||||
require.True(t, exists, "key %s should exist", key)
|
||||
|
||||
got = append(got, readDomainsFromKey(t, key)...)
|
||||
}
|
||||
|
||||
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
|
||||
assert.Equal(t, tc.count, len(got), "all domains should be readable")
|
||||
assert.Equal(t, domains, got, "domains should match in order")
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func checkDNSKeyExists(key string) (bool, error) {
|
||||
cmd := exec.Command(scutilPath)
|
||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||
@@ -158,15 +376,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
|
||||
createdKeys: make(map[string]struct{}),
|
||||
}
|
||||
|
||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||
|
||||
cleanup := func() {
|
||||
_ = sm.Stop(context.Background())
|
||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
||||
for key := range configurator.createdKeys {
|
||||
_ = removeTestDNSKey(key)
|
||||
}
|
||||
// Also clean up old-format keys and local key in case they exist
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
|
||||
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
|
||||
}
|
||||
|
||||
return configurator, sm, cleanup
|
||||
|
||||
@@ -277,7 +277,7 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
||||
}
|
||||
}
|
||||
|
||||
log.Infof("added %d NRPT rules for %d domains. Domain list: %v", ruleIndex, len(domains), domains)
|
||||
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
|
||||
return ruleIndex, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
|
||||
return "local-resolver"
|
||||
}
|
||||
|
||||
func (d *Resolver) ProbeAvailability() {}
|
||||
func (d *Resolver) ProbeAvailability(context.Context) {}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
|
||||
@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
||||
}
|
||||
}
|
||||
|
||||
if serverDomains.Flow != "" {
|
||||
domains = append(domains, serverDomains.Flow)
|
||||
}
|
||||
// Flow receiver domain is intentionally excluded from caching.
|
||||
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
||||
// causes TLS certificate verification failures on reconnect.
|
||||
|
||||
for _, stun := range serverDomains.Stuns {
|
||||
if stun != "" {
|
||||
|
||||
@@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
}
|
||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||
|
||||
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
|
||||
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
|
||||
partialDomains := dnsconfig.ServerDomains{
|
||||
Flow: "github.com",
|
||||
}
|
||||
@@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||
}
|
||||
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
|
||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||
|
||||
finalDomains := resolver.GetCachedDomains()
|
||||
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
|
||||
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
|
||||
|
||||
domainStrings := make([]string, len(finalDomains))
|
||||
for i, d := range finalDomains {
|
||||
@@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
||||
assert.Contains(t, domainStrings, "example.org")
|
||||
assert.Contains(t, domainStrings, "google.com")
|
||||
assert.Contains(t, domainStrings, "cloudflare.com")
|
||||
assert.Contains(t, domainStrings, "github.com")
|
||||
assert.NotContains(t, domainStrings, "github.com")
|
||||
}
|
||||
|
||||
@@ -104,12 +104,16 @@ type DefaultServer struct {
|
||||
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
|
||||
probeMu sync.Mutex
|
||||
probeCancel context.CancelFunc
|
||||
probeWg sync.WaitGroup
|
||||
}
|
||||
|
||||
type handlerWithStop interface {
|
||||
dns.Handler
|
||||
Stop()
|
||||
ProbeAvailability()
|
||||
ProbeAvailability(context.Context)
|
||||
ID() types.HandlerID
|
||||
}
|
||||
|
||||
@@ -362,7 +366,13 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
||||
|
||||
// Stop stops the server
|
||||
func (s *DefaultServer) Stop() {
|
||||
s.probeMu.Lock()
|
||||
if s.probeCancel != nil {
|
||||
s.probeCancel()
|
||||
}
|
||||
s.ctxCancel()
|
||||
s.probeMu.Unlock()
|
||||
s.probeWg.Wait()
|
||||
s.shutdownWg.Wait()
|
||||
|
||||
s.mux.Lock()
|
||||
@@ -479,7 +489,8 @@ func (s *DefaultServer) SearchDomains() []string {
|
||||
}
|
||||
|
||||
// ProbeAvailability tests each upstream group's servers for availability
|
||||
// and deactivates the group if no server responds
|
||||
// and deactivates the group if no server responds.
|
||||
// If a previous probe is still running, it will be cancelled before starting a new one.
|
||||
func (s *DefaultServer) ProbeAvailability() {
|
||||
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||
skipProbe, err := strconv.ParseBool(val)
|
||||
@@ -492,15 +503,52 @@ func (s *DefaultServer) ProbeAvailability() {
|
||||
}
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
wg.Add(1)
|
||||
go func(mux handlerWithStop) {
|
||||
defer wg.Done()
|
||||
mux.ProbeAvailability()
|
||||
}(mux.handler)
|
||||
s.probeMu.Lock()
|
||||
|
||||
// don't start probes on a stopped server
|
||||
if s.ctx.Err() != nil {
|
||||
s.probeMu.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// cancel any running probe
|
||||
if s.probeCancel != nil {
|
||||
s.probeCancel()
|
||||
s.probeCancel = nil
|
||||
}
|
||||
|
||||
// wait for the previous probe goroutines to finish while holding
|
||||
// the mutex so no other caller can start a new probe concurrently
|
||||
s.probeWg.Wait()
|
||||
|
||||
// start a new probe
|
||||
probeCtx, probeCancel := context.WithCancel(s.ctx)
|
||||
s.probeCancel = probeCancel
|
||||
|
||||
s.probeWg.Add(1)
|
||||
defer s.probeWg.Done()
|
||||
|
||||
// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
|
||||
s.mux.Lock()
|
||||
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
|
||||
for _, mux := range s.dnsMuxMap {
|
||||
handlers = append(handlers, mux.handler)
|
||||
}
|
||||
s.mux.Unlock()
|
||||
|
||||
var wg sync.WaitGroup
|
||||
for _, handler := range handlers {
|
||||
wg.Add(1)
|
||||
go func(h handlerWithStop) {
|
||||
defer wg.Done()
|
||||
h.ProbeAvailability(probeCtx)
|
||||
}(handler)
|
||||
}
|
||||
|
||||
s.probeMu.Unlock()
|
||||
|
||||
wg.Wait()
|
||||
probeCancel()
|
||||
}
|
||||
|
||||
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||
|
||||
@@ -1065,7 +1065,7 @@ type mockHandler struct {
|
||||
|
||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||
func (m *mockHandler) Stop() {}
|
||||
func (m *mockHandler) ProbeAvailability() {}
|
||||
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||
|
||||
type mockService struct{}
|
||||
|
||||
@@ -6,6 +6,7 @@ import (
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -69,7 +70,7 @@ func (s *serviceViaListener) Listen() error {
|
||||
return fmt.Errorf("eval listen address: %w", err)
|
||||
}
|
||||
s.listenIP = s.listenIP.Unmap()
|
||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
||||
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||
log.Debugf("starting dns on %s", s.server.Addr)
|
||||
go func() {
|
||||
s.setListenerStatus(true)
|
||||
@@ -186,7 +187,7 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
||||
}
|
||||
|
||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
||||
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||
if err != nil {
|
||||
|
||||
@@ -65,6 +65,7 @@ type upstreamResolverBase struct {
|
||||
mutex sync.Mutex
|
||||
reactivatePeriod time.Duration
|
||||
upstreamTimeout time.Duration
|
||||
wg sync.WaitGroup
|
||||
|
||||
deactivate func(error)
|
||||
reactivate func()
|
||||
@@ -115,6 +116,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
||||
func (u *upstreamResolverBase) Stop() {
|
||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||
u.cancel()
|
||||
|
||||
u.mutex.Lock()
|
||||
u.wg.Wait()
|
||||
u.mutex.Unlock()
|
||||
|
||||
}
|
||||
|
||||
// ServeDNS handles a DNS request
|
||||
@@ -260,16 +266,10 @@ func formatFailures(failures []upstreamFailure) string {
|
||||
|
||||
// ProbeAvailability tests all upstream servers simultaneously and
|
||||
// disables the resolver if none work
|
||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||
u.mutex.Lock()
|
||||
defer u.mutex.Unlock()
|
||||
|
||||
select {
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// avoid probe if upstreams could resolve at least one query
|
||||
if u.successCount.Load() > 0 {
|
||||
return
|
||||
@@ -279,31 +279,39 @@ func (u *upstreamResolverBase) ProbeAvailability() {
|
||||
var mu sync.Mutex
|
||||
var wg sync.WaitGroup
|
||||
|
||||
var errors *multierror.Error
|
||||
var errs *multierror.Error
|
||||
for _, upstream := range u.upstreamServers {
|
||||
upstream := upstream
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
go func(upstream netip.AddrPort) {
|
||||
defer wg.Done()
|
||||
err := u.testNameserver(upstream, 500*time.Millisecond)
|
||||
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
||||
if err != nil {
|
||||
errors = multierror.Append(errors, err)
|
||||
mu.Lock()
|
||||
errs = multierror.Append(errs, err)
|
||||
mu.Unlock()
|
||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||
return
|
||||
}
|
||||
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
success = true
|
||||
}()
|
||||
mu.Unlock()
|
||||
}(upstream)
|
||||
}
|
||||
|
||||
wg.Wait()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-u.ctx.Done():
|
||||
return
|
||||
default:
|
||||
}
|
||||
|
||||
// didn't find a working upstream server, let's disable and try later
|
||||
if !success {
|
||||
u.disable(errors.ErrorOrNil())
|
||||
u.disable(errs.ErrorOrNil())
|
||||
|
||||
if u.statusRecorder == nil {
|
||||
return
|
||||
@@ -339,7 +347,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
}
|
||||
|
||||
for _, upstream := range u.upstreamServers {
|
||||
if err := u.testNameserver(upstream, probeTimeout); err != nil {
|
||||
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||
} else {
|
||||
// at least one upstream server is available, stop probing
|
||||
@@ -351,16 +359,22 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
||||
return fmt.Errorf("upstream check call error")
|
||||
}
|
||||
|
||||
err := backoff.Retry(operation, exponentialBackOff)
|
||||
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||
if err != nil {
|
||||
log.Warn(err)
|
||||
if errors.Is(err, context.Canceled) {
|
||||
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||
} else {
|
||||
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||
u.successCount.Add(1)
|
||||
u.reactivate()
|
||||
u.mutex.Lock()
|
||||
u.disabled = false
|
||||
u.mutex.Unlock()
|
||||
}
|
||||
|
||||
// isTimeout returns true if the given error is a network timeout error.
|
||||
@@ -383,7 +397,11 @@ func (u *upstreamResolverBase) disable(err error) {
|
||||
u.successCount.Store(0)
|
||||
u.deactivate(err)
|
||||
u.disabled = true
|
||||
go u.waitUntilResponse()
|
||||
u.wg.Add(1)
|
||||
go func() {
|
||||
defer u.wg.Done()
|
||||
u.waitUntilResponse()
|
||||
}()
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
@@ -394,13 +412,18 @@ func (u *upstreamResolverBase) upstreamServersString() string {
|
||||
return strings.Join(servers, ", ")
|
||||
}
|
||||
|
||||
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
|
||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
||||
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
||||
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
||||
defer cancel()
|
||||
|
||||
if externalCtx != nil {
|
||||
stop2 := context.AfterFunc(externalCtx, cancel)
|
||||
defer stop2()
|
||||
}
|
||||
|
||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||
|
||||
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
|
||||
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
||||
return err
|
||||
}
|
||||
|
||||
|
||||
@@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
||||
reactivated = true
|
||||
}
|
||||
|
||||
resolver.ProbeAvailability()
|
||||
resolver.ProbeAvailability(context.TODO())
|
||||
|
||||
if !failed {
|
||||
t.Errorf("expected that resolving was deactivated")
|
||||
|
||||
@@ -28,14 +28,15 @@ import (
|
||||
"github.com/netbirdio/netbird/client/firewall"
|
||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||
"github.com/netbirdio/netbird/client/iface"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/device"
|
||||
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||
"github.com/netbirdio/netbird/client/internal/acl"
|
||||
"github.com/netbirdio/netbird/client/internal/debug"
|
||||
"github.com/netbirdio/netbird/client/internal/dns"
|
||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||
"github.com/netbirdio/netbird/client/internal/expose"
|
||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||
@@ -50,16 +51,14 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updater"
|
||||
"github.com/netbirdio/netbird/client/jobexec"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
|
||||
"github.com/netbirdio/netbird/client/system"
|
||||
nbdns "github.com/netbirdio/netbird/dns"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
"github.com/netbirdio/netbird/shared/management/domain"
|
||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
@@ -75,13 +74,11 @@ import (
|
||||
const (
|
||||
PeerConnectionTimeoutMax = 45000 // ms
|
||||
PeerConnectionTimeoutMin = 30000 // ms
|
||||
connInitLimit = 200
|
||||
disableAutoUpdate = "disabled"
|
||||
)
|
||||
|
||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||
|
||||
// EngineConfig is a config for the Engine
|
||||
type EngineConfig struct {
|
||||
WgPort int
|
||||
WgIfaceName string
|
||||
@@ -143,6 +140,17 @@ type EngineConfig struct {
|
||||
LogPath string
|
||||
}
|
||||
|
||||
// EngineServices holds the external service dependencies required by the Engine.
|
||||
type EngineServices struct {
|
||||
SignalClient signal.Client
|
||||
MgmClient mgm.Client
|
||||
RelayManager *relayClient.Manager
|
||||
StatusRecorder *peer.Status
|
||||
Checks []*mgmProto.Checks
|
||||
StateManager *statemanager.Manager
|
||||
UpdateManager *updater.Manager
|
||||
}
|
||||
|
||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||
type Engine struct {
|
||||
// signal is a Signal Service client
|
||||
@@ -208,11 +216,10 @@ type Engine struct {
|
||||
syncRespMux sync.RWMutex
|
||||
persistSyncResponse bool
|
||||
latestSyncResponse *mgmProto.SyncResponse
|
||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||
flowManager nftypes.FlowManager
|
||||
|
||||
// auto-update
|
||||
updateManager *updatemanager.Manager
|
||||
updateManager *updater.Manager
|
||||
|
||||
// WireGuard interface monitor
|
||||
wgIfaceMonitor *WGIfaceMonitor
|
||||
@@ -224,6 +231,8 @@ type Engine struct {
|
||||
|
||||
jobExecutor *jobexec.Executor
|
||||
jobExecutorWG sync.WaitGroup
|
||||
|
||||
exposeManager *expose.Manager
|
||||
}
|
||||
|
||||
// Peer is an instance of the Connection Peer
|
||||
@@ -240,22 +249,17 @@ type localIpUpdater interface {
|
||||
func NewEngine(
|
||||
clientCtx context.Context,
|
||||
clientCancel context.CancelFunc,
|
||||
signalClient signal.Client,
|
||||
mgmClient mgm.Client,
|
||||
relayManager *relayClient.Manager,
|
||||
config *EngineConfig,
|
||||
services EngineServices,
|
||||
mobileDep MobileDependency,
|
||||
statusRecorder *peer.Status,
|
||||
checks []*mgmProto.Checks,
|
||||
stateManager *statemanager.Manager,
|
||||
) *Engine {
|
||||
engine := &Engine{
|
||||
clientCtx: clientCtx,
|
||||
clientCancel: clientCancel,
|
||||
signal: signalClient,
|
||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
||||
mgmClient: mgmClient,
|
||||
relayManager: relayManager,
|
||||
signal: services.SignalClient,
|
||||
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||
mgmClient: services.MgmClient,
|
||||
relayManager: services.RelayManager,
|
||||
peerStore: peerstore.NewConnStore(),
|
||||
syncMsgMux: &sync.Mutex{},
|
||||
config: config,
|
||||
@@ -263,12 +267,12 @@ func NewEngine(
|
||||
STUNs: []*stun.URI{},
|
||||
TURNs: []*stun.URI{},
|
||||
networkSerial: 0,
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
checks: checks,
|
||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
||||
statusRecorder: services.StatusRecorder,
|
||||
stateManager: services.StateManager,
|
||||
checks: services.Checks,
|
||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||
jobExecutor: jobexec.NewExecutor(),
|
||||
updateManager: services.UpdateManager,
|
||||
}
|
||||
|
||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||
@@ -311,7 +315,7 @@ func (e *Engine) Stop() error {
|
||||
}
|
||||
|
||||
if e.updateManager != nil {
|
||||
e.updateManager.Stop()
|
||||
e.updateManager.SetDownloadOnly()
|
||||
}
|
||||
|
||||
log.Info("cleaning up status recorder states")
|
||||
@@ -419,6 +423,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
e.cancel()
|
||||
}
|
||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||
|
||||
wgIface, err := e.newWgIface()
|
||||
if err != nil {
|
||||
@@ -560,13 +565,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
|
||||
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
||||
}
|
||||
|
||||
func (e *Engine) createFirewall() error {
|
||||
if e.config.DisableFirewall {
|
||||
log.Infof("firewall is disabled")
|
||||
@@ -794,39 +792,22 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||
if e.updateManager == nil {
|
||||
return
|
||||
}
|
||||
|
||||
if autoUpdateSettings == nil {
|
||||
return
|
||||
}
|
||||
|
||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
||||
|
||||
// Stop and cleanup if disabled
|
||||
if e.updateManager != nil && disabled {
|
||||
log.Infof("auto-update is disabled, stopping update manager")
|
||||
e.updateManager.Stop()
|
||||
e.updateManager = nil
|
||||
if autoUpdateSettings.Version == disableAutoUpdate {
|
||||
log.Infof("auto-update is disabled")
|
||||
e.updateManager.SetDownloadOnly()
|
||||
return
|
||||
}
|
||||
|
||||
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
||||
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
||||
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
||||
return
|
||||
}
|
||||
|
||||
// Start manager if needed
|
||||
if e.updateManager == nil {
|
||||
log.Infof("starting auto-update manager")
|
||||
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
e.updateManager = updateManager
|
||||
e.updateManager.Start(e.ctx)
|
||||
}
|
||||
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
||||
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
||||
}
|
||||
|
||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
@@ -843,7 +824,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||
}
|
||||
|
||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||
}
|
||||
|
||||
if update.GetNetbirdConfig() != nil {
|
||||
@@ -1008,10 +989,11 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
||||
return errors.New("wireguard interface is not initialized")
|
||||
}
|
||||
|
||||
// Cannot update the IP address without restarting the engine because
|
||||
// the firewall, route manager, and other components cache the old address
|
||||
if e.wgInterface.Address().String() != conf.Address {
|
||||
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
|
||||
log.Infof("peer IP address changed from %s to %s, restarting client", e.wgInterface.Address().String(), conf.Address)
|
||||
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||
e.clientCancel()
|
||||
return ErrResetConnection
|
||||
}
|
||||
|
||||
if conf.GetSshConfig() != nil {
|
||||
@@ -1316,8 +1298,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
||||
|
||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||
e.dnsServer.ProbeAvailability()
|
||||
|
||||
go e.dnsServer.ProbeAvailability()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -1539,7 +1520,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||
RelayManager: e.relayManager,
|
||||
SrWatcher: e.srWatcher,
|
||||
Semaphore: e.connSemaphore,
|
||||
}
|
||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||
if err != nil {
|
||||
@@ -1562,8 +1542,10 @@ func (e *Engine) receiveSignalEvents() {
|
||||
defer e.shutdownWg.Done()
|
||||
// connect to a stream of messages coming from the signal server
|
||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||
start := time.Now()
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
gotLock := time.Since(start)
|
||||
|
||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||
if e.ctx.Err() != nil {
|
||||
@@ -1587,6 +1569,8 @@ func (e *Engine) receiveSignalEvents() {
|
||||
return err
|
||||
}
|
||||
|
||||
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
|
||||
|
||||
if msg.Body.Type == sProto.Body_OFFER {
|
||||
conn.OnRemoteOffer(*offerAnswer)
|
||||
} else {
|
||||
@@ -1820,11 +1804,18 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
||||
return e.routeManager
|
||||
}
|
||||
|
||||
// GetFirewallManager returns the firewall manager
|
||||
// GetFirewallManager returns the firewall manager.
|
||||
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||
return e.firewall
|
||||
}
|
||||
|
||||
// GetExposeManager returns the expose session manager.
|
||||
func (e *Engine) GetExposeManager() *expose.Manager {
|
||||
e.syncMsgMux.Lock()
|
||||
defer e.syncMsgMux.Unlock()
|
||||
return e.exposeManager
|
||||
}
|
||||
|
||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||
iface, err := net.InterfaceByName(ifaceName)
|
||||
if err != nil {
|
||||
|
||||
@@ -251,9 +251,6 @@ func TestEngine_SSH(t *testing.T) {
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(
|
||||
ctx, cancel,
|
||||
&signal.MockClient{},
|
||||
&mgmt.MockClient{},
|
||||
relayMgr,
|
||||
&EngineConfig{
|
||||
WgIfaceName: "utun101",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
@@ -263,10 +260,13 @@ func TestEngine_SSH(t *testing.T) {
|
||||
MTU: iface.DefaultMTU,
|
||||
SSHKey: sshKey,
|
||||
},
|
||||
EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
},
|
||||
MobileDependency{},
|
||||
peer.NewRecorder("https://mgm"),
|
||||
nil,
|
||||
nil,
|
||||
)
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -428,13 +428,18 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
||||
defer cancel()
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun102",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
|
||||
wgIface := &MockWGIface{
|
||||
NameFunc: func() string { return "utun102" },
|
||||
@@ -647,13 +652,18 @@ func TestEngine_Sync(t *testing.T) {
|
||||
return nil
|
||||
}
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: "utun103",
|
||||
WgAddr: "100.64.0.1/24",
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
engine.dnsServer = &dns.MockServer{
|
||||
@@ -812,13 +822,18 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
if err != nil {
|
||||
@@ -1014,13 +1029,18 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
||||
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||
WgIfaceName: wgIfaceName,
|
||||
WgAddr: wgAddr,
|
||||
WgPrivateKey: key,
|
||||
WgPort: 33100,
|
||||
MTU: iface.DefaultMTU,
|
||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
||||
}, EngineServices{
|
||||
SignalClient: &signal.MockClient{},
|
||||
MgmClient: &mgmt.MockClient{},
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{})
|
||||
engine.ctx = ctx
|
||||
|
||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||
@@ -1546,7 +1566,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
||||
}
|
||||
|
||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
||||
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||
SignalClient: signalClient,
|
||||
MgmClient: mgmtClient,
|
||||
RelayManager: relayMgr,
|
||||
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||
}, MobileDependency{}), nil
|
||||
e.ctx = ctx
|
||||
return e, err
|
||||
}
|
||||
|
||||
97
client/internal/expose/manager.go
Normal file
97
client/internal/expose/manager.go
Normal file
@@ -0,0 +1,97 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"time"
|
||||
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const renewTimeout = 10 * time.Second
|
||||
|
||||
// Response holds the response from exposing a service.
|
||||
type Response struct {
|
||||
ServiceName string
|
||||
ServiceURL string
|
||||
Domain string
|
||||
PortAutoAssigned bool
|
||||
}
|
||||
|
||||
type Request struct {
|
||||
NamePrefix string
|
||||
Domain string
|
||||
Port uint16
|
||||
Protocol int
|
||||
Pin string
|
||||
Password string
|
||||
UserGroups []string
|
||||
ListenPort uint16
|
||||
}
|
||||
|
||||
type ManagementClient interface {
|
||||
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
|
||||
RenewExpose(ctx context.Context, domain string) error
|
||||
StopExpose(ctx context.Context, domain string) error
|
||||
}
|
||||
|
||||
// Manager handles expose session lifecycle via the management client.
|
||||
type Manager struct {
|
||||
mgmClient ManagementClient
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
// NewManager creates a new expose Manager using the given management client.
|
||||
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
|
||||
return &Manager{mgmClient: mgmClient, ctx: ctx}
|
||||
}
|
||||
|
||||
// Expose creates a new expose session via the management server.
|
||||
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
|
||||
log.Infof("exposing service on port %d", req.Port)
|
||||
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
log.Infof("expose session created for %s", resp.Domain)
|
||||
|
||||
return fromClientExposeResponse(resp), nil
|
||||
}
|
||||
|
||||
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
|
||||
ticker := time.NewTicker(30 * time.Second)
|
||||
defer ticker.Stop()
|
||||
defer m.stop(domain)
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Infof("context canceled, stopping keep alive for %s", domain)
|
||||
|
||||
return nil
|
||||
case <-ticker.C:
|
||||
if err := m.renew(ctx, domain); err != nil {
|
||||
log.Errorf("renewing expose session for %s: %v", domain, err)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// renew extends the TTL of an active expose session.
|
||||
func (m *Manager) renew(ctx context.Context, domain string) error {
|
||||
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
|
||||
defer cancel()
|
||||
return m.mgmClient.RenewExpose(renewCtx, domain)
|
||||
}
|
||||
|
||||
// stop terminates an active expose session.
|
||||
func (m *Manager) stop(domain string) {
|
||||
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
|
||||
defer cancel()
|
||||
err := m.mgmClient.StopExpose(stopCtx, domain)
|
||||
if err != nil {
|
||||
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
|
||||
}
|
||||
}
|
||||
95
client/internal/expose/manager_test.go
Normal file
95
client/internal/expose/manager_test.go
Normal file
@@ -0,0 +1,95 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
func TestManager_Expose_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return &mgm.ExposeResponse{
|
||||
ServiceName: "my-service",
|
||||
ServiceURL: "https://my-service.example.com",
|
||||
Domain: "my-service.example.com",
|
||||
}, nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
result, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
|
||||
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
|
||||
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
|
||||
}
|
||||
|
||||
func TestManager_Expose_Error(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||
return nil, errors.New("permission denied")
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
_, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||
require.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
|
||||
}
|
||||
|
||||
func TestManager_Renew_Success(t *testing.T) {
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
|
||||
return nil
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(context.Background(), mock)
|
||||
err := m.renew(context.Background(), "my-service.example.com")
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestManager_Renew_Timeout(t *testing.T) {
|
||||
ctx, cancel := context.WithCancel(context.Background())
|
||||
cancel()
|
||||
|
||||
mock := &mgm.MockClient{
|
||||
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||
return ctx.Err()
|
||||
},
|
||||
}
|
||||
|
||||
m := NewManager(ctx, mock)
|
||||
err := m.renew(ctx, "my-service.example.com")
|
||||
require.Error(t, err)
|
||||
}
|
||||
|
||||
func TestNewRequest(t *testing.T) {
|
||||
req := &daemonProto.ExposeServiceRequest{
|
||||
Port: 8080,
|
||||
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
|
||||
Pin: "123456",
|
||||
Password: "secret",
|
||||
UserGroups: []string{"group1", "group2"},
|
||||
Domain: "custom.example.com",
|
||||
NamePrefix: "my-prefix",
|
||||
}
|
||||
|
||||
exposeReq := NewRequest(req)
|
||||
|
||||
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||
assert.Equal(t, int(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
|
||||
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
|
||||
}
|
||||
42
client/internal/expose/request.go
Normal file
42
client/internal/expose/request.go
Normal file
@@ -0,0 +1,42 @@
|
||||
package expose
|
||||
|
||||
import (
|
||||
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||
)
|
||||
|
||||
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
|
||||
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||
return &Request{
|
||||
Port: uint16(req.Port),
|
||||
Protocol: int(req.Protocol),
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
Domain: req.Domain,
|
||||
NamePrefix: req.NamePrefix,
|
||||
ListenPort: uint16(req.ListenPort),
|
||||
}
|
||||
}
|
||||
|
||||
func toClientExposeRequest(req Request) mgm.ExposeRequest {
|
||||
return mgm.ExposeRequest{
|
||||
NamePrefix: req.NamePrefix,
|
||||
Domain: req.Domain,
|
||||
Port: req.Port,
|
||||
Protocol: req.Protocol,
|
||||
Pin: req.Pin,
|
||||
Password: req.Password,
|
||||
UserGroups: req.UserGroups,
|
||||
ListenPort: req.ListenPort,
|
||||
}
|
||||
}
|
||||
|
||||
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
|
||||
return &Response{
|
||||
ServiceName: response.ServiceName,
|
||||
Domain: response.Domain,
|
||||
ServiceURL: response.ServiceURL,
|
||||
PortAutoAssigned: response.PortAutoAssigned,
|
||||
}
|
||||
}
|
||||
@@ -22,51 +22,56 @@ func prepareFd() (int, error) {
|
||||
|
||||
func routeCheck(ctx context.Context, fd int, nexthopv4, nexthopv6 systemops.Nexthop) error {
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
default:
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
// Wait until fd is readable or context is cancelled, to avoid a busy-loop
|
||||
// when the routing socket returns EAGAIN (e.g. immediately after wakeup).
|
||||
if err := waitReadable(ctx, fd); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
buf := make([]byte, 2048)
|
||||
n, err := unix.Read(fd, buf)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EAGAIN) || errors.Is(err, unix.EINTR) {
|
||||
continue
|
||||
}
|
||||
if errors.Is(err, unix.EBADF) || errors.Is(err, unix.EINVAL) {
|
||||
return fmt.Errorf("routing socket closed: %w", err)
|
||||
}
|
||||
return fmt.Errorf("read routing socket: %w", err)
|
||||
}
|
||||
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
if !errors.Is(err, unix.EBADF) && !errors.Is(err, unix.EINVAL) {
|
||||
log.Warnf("Network monitor: failed to read from routing socket: %v", err)
|
||||
}
|
||||
continue
|
||||
}
|
||||
if n < unix.SizeofRtMsghdr {
|
||||
log.Debugf("Network monitor: read from routing socket returned less than expected: %d bytes", n)
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
msg := (*unix.RtMsghdr)(unsafe.Pointer(&buf[0]))
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
// handle route changes
|
||||
case unix.RTM_ADD, syscall.RTM_DELETE:
|
||||
route, err := parseRouteMessage(buf[:n])
|
||||
if err != nil {
|
||||
log.Debugf("Network monitor: error parsing routing message: %v", err)
|
||||
continue
|
||||
}
|
||||
|
||||
if route.Dst.Bits() != 0 {
|
||||
continue
|
||||
}
|
||||
|
||||
intf := "<nil>"
|
||||
if route.Interface != nil {
|
||||
intf = route.Interface.Name
|
||||
}
|
||||
switch msg.Type {
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
case unix.RTM_ADD:
|
||||
log.Infof("Network monitor: default route changed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
case unix.RTM_DELETE:
|
||||
if nexthopv4.Intf != nil && route.Gw.Compare(nexthopv4.IP) == 0 || nexthopv6.Intf != nil && route.Gw.Compare(nexthopv6.IP) == 0 {
|
||||
log.Infof("Network monitor: default route removed: via %s, interface %s", route.Gw, intf)
|
||||
return nil
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -90,3 +95,33 @@ func parseRouteMessage(buf []byte) (*systemops.Route, error) {
|
||||
|
||||
return systemops.MsgToRoute(msg)
|
||||
}
|
||||
|
||||
// waitReadable blocks until fd has data to read, or ctx is cancelled.
|
||||
func waitReadable(ctx context.Context, fd int) error {
|
||||
var fdset unix.FdSet
|
||||
if fd < 0 || fd/unix.NFDBITS >= len(fdset.Bits) {
|
||||
return fmt.Errorf("fd %d out of range for FdSet", fd)
|
||||
}
|
||||
|
||||
for {
|
||||
if err := ctx.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
fdset = unix.FdSet{}
|
||||
fdset.Set(fd)
|
||||
// Use a 1-second timeout so we can re-check ctx periodically.
|
||||
tv := unix.Timeval{Sec: 1}
|
||||
n, err := unix.Select(fd+1, &fdset, nil, nil, &tv)
|
||||
if err != nil {
|
||||
if errors.Is(err, unix.EINTR) {
|
||||
continue
|
||||
}
|
||||
return fmt.Errorf("select on routing socket: %w", err)
|
||||
}
|
||||
if n > 0 {
|
||||
return nil
|
||||
}
|
||||
// timeout — loop back and re-check ctx
|
||||
}
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package peer
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"math/rand"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
@@ -25,7 +24,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/route"
|
||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
type ServiceDependencies struct {
|
||||
@@ -34,7 +32,6 @@ type ServiceDependencies struct {
|
||||
IFaceDiscover stdnet.ExternalIFaceDiscover
|
||||
RelayManager *relayClient.Manager
|
||||
SrWatcher *guard.SRWatcher
|
||||
Semaphore *semaphoregroup.SemaphoreGroup
|
||||
PeerConnDispatcher *dispatcher.ConnectionDispatcher
|
||||
}
|
||||
|
||||
@@ -111,9 +108,8 @@ type Conn struct {
|
||||
wgProxyRelay wgproxy.Proxy
|
||||
handshaker *Handshaker
|
||||
|
||||
guard *guard.Guard
|
||||
semaphore *semaphoregroup.SemaphoreGroup
|
||||
wg sync.WaitGroup
|
||||
guard *guard.Guard
|
||||
wg sync.WaitGroup
|
||||
|
||||
// debug purpose
|
||||
dumpState *stateDump
|
||||
@@ -139,7 +135,6 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
iFaceDiscover: services.IFaceDiscover,
|
||||
relayManager: services.RelayManager,
|
||||
srWatcher: services.SrWatcher,
|
||||
semaphore: services.Semaphore,
|
||||
statusRelay: worker.NewAtomicStatus(),
|
||||
statusICE: worker.NewAtomicStatus(),
|
||||
dumpState: dumpState,
|
||||
@@ -154,15 +149,10 @@ func NewConn(config ConnConfig, services ServiceDependencies) (*Conn, error) {
|
||||
// It will try to establish a connection using ICE and in parallel with relay. The higher priority connection type will
|
||||
// be used.
|
||||
func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
if err := conn.semaphore.Add(engineCtx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn.mu.Lock()
|
||||
defer conn.mu.Unlock()
|
||||
|
||||
if conn.opened {
|
||||
conn.semaphore.Done()
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -173,7 +163,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
relayIsSupportedLocally := conn.workerRelay.RelayIsSupportedLocally()
|
||||
workerICE, err := NewWorkerICE(conn.ctx, conn.Log, conn.config, conn, conn.signaler, conn.iFaceDiscover, conn.statusRecorder, relayIsSupportedLocally)
|
||||
if err != nil {
|
||||
conn.semaphore.Done()
|
||||
return err
|
||||
}
|
||||
conn.workerICE = workerICE
|
||||
@@ -207,10 +196,6 @@ func (conn *Conn) Open(engineCtx context.Context) error {
|
||||
conn.wg.Add(1)
|
||||
go func() {
|
||||
defer conn.wg.Done()
|
||||
|
||||
conn.waitInitialRandomSleepTime(conn.ctx)
|
||||
conn.semaphore.Done()
|
||||
|
||||
conn.guard.Start(conn.ctx, conn.onGuardEvent)
|
||||
}()
|
||||
conn.opened = true
|
||||
@@ -434,14 +419,14 @@ func (conn *Conn) onICEStateDisconnected(sessionChanged bool) {
|
||||
conn.resetEndpoint()
|
||||
}
|
||||
|
||||
// todo consider to move after the ConfigureWGEndpoint
|
||||
conn.wgProxyRelay.Work()
|
||||
|
||||
presharedKey := conn.presharedKey(conn.rosenpassRemoteKey)
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.endpointUpdater.SwitchWGEndpoint(conn.wgProxyRelay.EndpointAddr(), presharedKey); err != nil {
|
||||
conn.Log.Errorf("failed to switch to relay conn: %v", err)
|
||||
}
|
||||
|
||||
conn.wgProxyRelay.Work()
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
} else {
|
||||
conn.Log.Infof("ICE disconnected, do not switch to Relay. Reset priority to: %s", conntype.None.String())
|
||||
@@ -503,20 +488,22 @@ func (conn *Conn) onRelayConnectionIsReady(rci RelayConnInfo) {
|
||||
return
|
||||
}
|
||||
|
||||
wgProxy.Work()
|
||||
presharedKey := conn.presharedKey(rci.rosenpassPubKey)
|
||||
controller := isController(conn.config)
|
||||
|
||||
if controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
conn.enableWgWatcherIfNeeded()
|
||||
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), presharedKey); err != nil {
|
||||
if err := conn.endpointUpdater.ConfigureWGEndpoint(wgProxy.EndpointAddr(), conn.presharedKey(rci.rosenpassPubKey)); err != nil {
|
||||
if err := wgProxy.CloseConn(); err != nil {
|
||||
conn.Log.Warnf("Failed to close relay connection: %v", err)
|
||||
}
|
||||
conn.Log.Errorf("Failed to update WireGuard peer configuration: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
wgConfigWorkaround()
|
||||
if !controller {
|
||||
wgProxy.Work()
|
||||
}
|
||||
conn.rosenpassRemoteKey = rci.rosenpassPubKey
|
||||
conn.currentConnPriority = conntype.Relay
|
||||
conn.statusRelay.SetConnected()
|
||||
@@ -668,19 +655,6 @@ func (conn *Conn) doOnConnected(remoteRosenpassPubKey []byte, remoteRosenpassAdd
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) waitInitialRandomSleepTime(ctx context.Context) {
|
||||
maxWait := 300
|
||||
duration := time.Duration(rand.Intn(maxWait)) * time.Millisecond
|
||||
|
||||
timeout := time.NewTimer(duration)
|
||||
defer timeout.Stop()
|
||||
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
case <-timeout.C:
|
||||
}
|
||||
}
|
||||
|
||||
func (conn *Conn) isRelayed() bool {
|
||||
switch conn.currentConnPriority {
|
||||
case conntype.Relay, conntype.ICETurn:
|
||||
@@ -877,9 +851,3 @@ func isController(config ConnConfig) bool {
|
||||
func isRosenpassEnabled(remoteRosenpassPubKey []byte) bool {
|
||||
return remoteRosenpassPubKey != nil
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -15,7 +15,6 @@ import (
|
||||
"github.com/netbirdio/netbird/client/internal/peer/ice"
|
||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||
"github.com/netbirdio/netbird/util"
|
||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
||||
)
|
||||
|
||||
var testDispatcher = dispatcher.NewConnectionDispatcher()
|
||||
@@ -53,7 +52,6 @@ func TestConn_GetKey(t *testing.T) {
|
||||
|
||||
sd := ServiceDependencies{
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
@@ -71,7 +69,6 @@ func TestConn_OnRemoteOffer(t *testing.T) {
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
@@ -110,7 +107,6 @@ func TestConn_OnRemoteAnswer(t *testing.T) {
|
||||
sd := ServiceDependencies{
|
||||
StatusRecorder: NewRecorder("https://mgm"),
|
||||
SrWatcher: swWatcher,
|
||||
Semaphore: semaphoregroup.NewSemaphoreGroup(1),
|
||||
PeerConnDispatcher: testDispatcher,
|
||||
}
|
||||
conn, err := NewConn(connConf, sd)
|
||||
|
||||
@@ -34,28 +34,27 @@ func NewEndpointUpdater(log *logrus.Entry, wgConfig WgConfig, initiator bool) *E
|
||||
}
|
||||
}
|
||||
|
||||
// ConfigureWGEndpoint sets up the WireGuard endpoint configuration.
|
||||
// The initiator immediately configures the endpoint, while the non-initiator
|
||||
// waits for a fallback period before configuring to avoid handshake congestion.
|
||||
func (e *EndpointUpdater) ConfigureWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
if e.initiator {
|
||||
e.log.Debugf("configure up WireGuard as initiatr")
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
e.log.Debugf("configure up WireGuard as initiator")
|
||||
return e.configureAsInitiator(addr, presharedKey)
|
||||
}
|
||||
|
||||
e.log.Debugf("configure up WireGuard as responder")
|
||||
return e.configureAsResponder(addr, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) SwitchWGEndpoint(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
return e.updateWireGuardPeer(nil, presharedKey)
|
||||
return e.updateWireGuardPeer(addr, presharedKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
@@ -67,9 +66,37 @@ func (e *EndpointUpdater) RemoveWgPeer() error {
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) RemoveEndpointAddress() error {
|
||||
e.mu.Lock()
|
||||
defer e.mu.Unlock()
|
||||
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return e.wgConfig.WgInterface.RemoveEndpointAddress(e.wgConfig.RemoteKey)
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsInitiator(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
if err := e.updateWireGuardPeer(addr, presharedKey); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) configureAsResponder(addr *net.UDPAddr, presharedKey *wgtypes.Key) error {
|
||||
// prevent to run new update while cancel the previous update
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
|
||||
e.log.Debugf("configure up WireGuard and wait for handshake")
|
||||
var ctx context.Context
|
||||
ctx, e.cancelFunc = context.WithCancel(context.Background())
|
||||
e.updateWg.Add(1)
|
||||
go e.scheduleDelayedUpdate(ctx, addr, presharedKey)
|
||||
|
||||
if err := e.updateWireGuardPeer(nil, presharedKey); err != nil {
|
||||
e.waitForCloseTheDelayedUpdate()
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (e *EndpointUpdater) waitForCloseTheDelayedUpdate() {
|
||||
if e.cancelFunc == nil {
|
||||
return
|
||||
@@ -105,3 +132,9 @@ func (e *EndpointUpdater) updateWireGuardPeer(endpoint *net.UDPAddr, presharedKe
|
||||
presharedKey,
|
||||
)
|
||||
}
|
||||
|
||||
// wgConfigWorkaround is a workaround for the issue with WireGuard configuration update
|
||||
// When update a peer configuration in near to each other time, the second update can be ignored by WireGuard
|
||||
func wgConfigWorkaround() {
|
||||
time.Sleep(100 * time.Millisecond)
|
||||
}
|
||||
|
||||
@@ -198,7 +198,7 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
|
||||
configDir := filepath.Join(DefaultConfigPathDir, username)
|
||||
if _, err := os.Stat(configDir); os.IsNotExist(err) {
|
||||
if err := os.MkdirAll(configDir, 0600); err != nil {
|
||||
if err := os.MkdirAll(configDir, 0700); err != nil {
|
||||
return "", err
|
||||
}
|
||||
}
|
||||
@@ -206,9 +206,15 @@ func getConfigDirForUser(username string) (string, error) {
|
||||
return configDir, nil
|
||||
}
|
||||
|
||||
func fileExists(path string) bool {
|
||||
func fileExists(path string) (bool, error) {
|
||||
_, err := os.Stat(path)
|
||||
return !os.IsNotExist(err)
|
||||
if err == nil {
|
||||
return true, nil
|
||||
}
|
||||
if os.IsNotExist(err) {
|
||||
return false, nil
|
||||
}
|
||||
return false, err
|
||||
}
|
||||
|
||||
// createNewConfig creates a new config generating a new Wireguard key and saving to file
|
||||
@@ -635,7 +641,11 @@ func isPreSharedKeyHidden(preSharedKey *string) bool {
|
||||
|
||||
// UpdateConfig update existing configuration according to input configuration and return with the configuration
|
||||
func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
return nil, fmt.Errorf("config file %s does not exist", input.ConfigPath)
|
||||
}
|
||||
|
||||
@@ -644,7 +654,11 @@ func UpdateConfig(input ConfigInput) (*Config, error) {
|
||||
|
||||
// UpdateOrCreateConfig reads existing config or generates a new one
|
||||
func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
@@ -657,7 +671,7 @@ func UpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if isPreSharedKeyHidden(input.PreSharedKey) {
|
||||
input.PreSharedKey = nil
|
||||
}
|
||||
err := util.EnforcePermission(input.ConfigPath)
|
||||
err = util.EnforcePermission(input.ConfigPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
}
|
||||
@@ -784,7 +798,12 @@ func ReadConfig(configPath string) (*Config, error) {
|
||||
|
||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||
func readConfig(configPath string, createIfMissing bool) (*Config, error) {
|
||||
if fileExists(configPath) {
|
||||
configExists, err := fileExists(configPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
|
||||
if configExists {
|
||||
err := util.EnforcePermission(configPath)
|
||||
if err != nil {
|
||||
log.Errorf("failed to enforce permission on config dir: %v", err)
|
||||
@@ -831,7 +850,11 @@ func DirectWriteOutConfig(path string, config *Config) error {
|
||||
// DirectUpdateOrCreateConfig is like UpdateOrCreateConfig but uses direct (non-atomic) writes.
|
||||
// Use this on platforms where atomic writes are blocked (e.g., tvOS sandbox).
|
||||
func DirectUpdateOrCreateConfig(input ConfigInput) (*Config, error) {
|
||||
if !fileExists(input.ConfigPath) {
|
||||
configExists, err := fileExists(input.ConfigPath)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if config file exists: %w", err)
|
||||
}
|
||||
if !configExists {
|
||||
log.Infof("generating new config %s", input.ConfigPath)
|
||||
cfg, err := createNewConfig(input)
|
||||
if err != nil {
|
||||
|
||||
@@ -256,7 +256,11 @@ func (s *ServiceManager) AddProfile(profileName, username string) error {
|
||||
}
|
||||
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
if fileExists(profPath) {
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if profileExists {
|
||||
return ErrProfileAlreadyExists
|
||||
}
|
||||
|
||||
@@ -285,7 +289,11 @@ func (s *ServiceManager) RemoveProfile(profileName, username string) error {
|
||||
return fmt.Errorf("cannot remove profile with reserved name: %s", defaultProfileName)
|
||||
}
|
||||
profPath := filepath.Join(configDir, profileName+".json")
|
||||
if !fileExists(profPath) {
|
||||
profileExists, err := fileExists(profPath)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if profile exists: %w", err)
|
||||
}
|
||||
if !profileExists {
|
||||
return ErrProfileNotFound
|
||||
}
|
||||
|
||||
|
||||
@@ -20,7 +20,11 @@ func (pm *ProfileManager) GetProfileState(profileName string) (*ProfileState, er
|
||||
}
|
||||
|
||||
stateFile := filepath.Join(configDir, profileName+".state.json")
|
||||
if !fileExists(stateFile) {
|
||||
stateFileExists, err := fileExists(stateFile)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to check if profile state file exists: %w", err)
|
||||
}
|
||||
if !stateFileExists {
|
||||
return nil, errors.New("profile state file does not exist")
|
||||
}
|
||||
|
||||
|
||||
@@ -3,7 +3,9 @@ package client
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"net"
|
||||
"reflect"
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
@@ -263,8 +265,14 @@ func (w *Watcher) watchPeerStatusChanges(ctx context.Context, peerKey string, pe
|
||||
case <-closer:
|
||||
return
|
||||
case routerStates := <-subscription.Events():
|
||||
peerStateUpdate <- routerStates
|
||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||
select {
|
||||
case peerStateUpdate <- routerStates:
|
||||
log.Debugf("triggered route state update for Peer: %s", peerKey)
|
||||
case <-ctx.Done():
|
||||
return
|
||||
case <-closer:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -558,7 +566,7 @@ func HandlerFromRoute(params common.HandlerParams) RouteHandler {
|
||||
return dnsinterceptor.New(params)
|
||||
case handlerTypeDynamic:
|
||||
dns := nbdns.NewServiceViaMemory(params.WgInterface)
|
||||
dnsAddr := fmt.Sprintf("%s:%d", dns.RuntimeIP(), dns.RuntimePort())
|
||||
dnsAddr := net.JoinHostPort(dns.RuntimeIP().String(), strconv.Itoa(dns.RuntimePort()))
|
||||
return dynamic.NewRoute(params, dnsAddr)
|
||||
default:
|
||||
return static.NewRoute(params)
|
||||
|
||||
@@ -4,8 +4,10 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/netip"
|
||||
"runtime"
|
||||
"strconv"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
@@ -249,7 +251,7 @@ func (d *DnsInterceptor) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||
r.MsgHdr.AuthenticatedData = true
|
||||
}
|
||||
|
||||
upstream := fmt.Sprintf("%s:%d", upstreamIP.String(), uint16(d.forwarderPort.Load()))
|
||||
upstream := net.JoinHostPort(upstreamIP.String(), strconv.FormatUint(uint64(d.forwarderPort.Load()), 10))
|
||||
ctx, cancel := context.WithTimeout(context.Background(), dnsTimeout)
|
||||
defer cancel()
|
||||
|
||||
@@ -351,6 +353,11 @@ func (d *DnsInterceptor) writeMsg(w dns.ResponseWriter, r *dns.Msg, logger *log.
|
||||
logger.Errorf("failed to update domain prefixes: %v", err)
|
||||
}
|
||||
|
||||
// Allow time for route changes to be applied before sending
|
||||
// the DNS response (relevant on iOS where setTunnelNetworkSettings
|
||||
// is asynchronous).
|
||||
waitForRouteSettlement(logger)
|
||||
|
||||
d.replaceIPsInDNSResponse(r, newPrefixes, logger)
|
||||
}
|
||||
}
|
||||
|
||||
20
client/internal/routemanager/dnsinterceptor/handler_ios.go
Normal file
20
client/internal/routemanager/dnsinterceptor/handler_ios.go
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build ios
|
||||
|
||||
package dnsinterceptor
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
)
|
||||
|
||||
const routeSettleDelay = 500 * time.Millisecond
|
||||
|
||||
// waitForRouteSettlement introduces a short delay on iOS to allow
|
||||
// setTunnelNetworkSettings to apply route changes before the DNS
|
||||
// response reaches the application. Without this, the first request
|
||||
// to a newly resolved domain may bypass the tunnel.
|
||||
func waitForRouteSettlement(logger *log.Entry) {
|
||||
logger.Tracef("waiting %v for iOS route settlement", routeSettleDelay)
|
||||
time.Sleep(routeSettleDelay)
|
||||
}
|
||||
@@ -0,0 +1,12 @@
|
||||
//go:build !ios
|
||||
|
||||
package dnsinterceptor
|
||||
|
||||
import log "github.com/sirupsen/logrus"
|
||||
|
||||
func waitForRouteSettlement(_ *log.Entry) {
|
||||
// No-op on non-iOS platforms: route changes are applied synchronously by
|
||||
// the kernel, so no settlement delay is needed before the DNS response
|
||||
// reaches the application. The delay is only required on iOS where
|
||||
// setTunnelNetworkSettings applies routes asynchronously.
|
||||
}
|
||||
80
client/internal/sleep/handler/handler.go
Normal file
80
client/internal/sleep/handler/handler.go
Normal file
@@ -0,0 +1,80 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type Agent interface {
|
||||
Up(ctx context.Context) error
|
||||
Down(ctx context.Context) error
|
||||
Status() (internal.StatusType, error)
|
||||
}
|
||||
|
||||
type SleepHandler struct {
|
||||
agent Agent
|
||||
|
||||
mu sync.Mutex
|
||||
// sleepTriggeredDown indicates whether the sleep handler triggered the last client down, to avoid unnecessary up on wake
|
||||
sleepTriggeredDown bool
|
||||
}
|
||||
|
||||
func New(agent Agent) *SleepHandler {
|
||||
return &SleepHandler{
|
||||
agent: agent,
|
||||
}
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleWakeUp(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
if !s.sleepTriggeredDown {
|
||||
log.Info("skipping up because wasn't sleep down")
|
||||
return nil
|
||||
}
|
||||
|
||||
// avoid other wakeup runs if sleep didn't make the computer sleep
|
||||
s.sleepTriggeredDown = false
|
||||
|
||||
log.Info("running up after wake up")
|
||||
err := s.agent.Up(ctx)
|
||||
if err != nil {
|
||||
log.Errorf("running up failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
log.Info("running up command executed successfully")
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *SleepHandler) HandleSleep(ctx context.Context) error {
|
||||
s.mu.Lock()
|
||||
defer s.mu.Unlock()
|
||||
|
||||
status, err := s.agent.Status()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if status != internal.StatusConnecting && status != internal.StatusConnected {
|
||||
log.Infof("skipping setting the agent down because status is %s", status)
|
||||
return nil
|
||||
}
|
||||
|
||||
log.Info("running down after system started sleeping")
|
||||
|
||||
if err = s.agent.Down(ctx); err != nil {
|
||||
log.Errorf("running down failed: %v", err)
|
||||
return err
|
||||
}
|
||||
|
||||
s.sleepTriggeredDown = true
|
||||
|
||||
log.Info("running down executed successfully")
|
||||
return nil
|
||||
}
|
||||
153
client/internal/sleep/handler/handler_test.go
Normal file
153
client/internal/sleep/handler/handler_test.go
Normal file
@@ -0,0 +1,153 @@
|
||||
package handler
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal"
|
||||
)
|
||||
|
||||
type mockAgent struct {
|
||||
upErr error
|
||||
downErr error
|
||||
statusErr error
|
||||
status internal.StatusType
|
||||
upCalls int
|
||||
}
|
||||
|
||||
func (m *mockAgent) Up(_ context.Context) error {
|
||||
m.upCalls++
|
||||
return m.upErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Down(_ context.Context) error {
|
||||
return m.downErr
|
||||
}
|
||||
|
||||
func (m *mockAgent) Status() (internal.StatusType, error) {
|
||||
return m.status, m.statusErr
|
||||
}
|
||||
|
||||
func newHandler(status internal.StatusType) (*SleepHandler, *mockAgent) {
|
||||
agent := &mockAgent{status: status}
|
||||
return New(agent), agent
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SkipsWhenFlagFalse(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 0, agent.upCalls, "Up should not be called when flag is false")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ResetsFlagBeforeUp(t *testing.T) {
|
||||
h, _ := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
// Even if Up fails, flag should be reset
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.False(t, h.sleepTriggeredDown, "flag must be reset before calling Up")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_CallsUpWhenFlagSet(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_ReturnsErrorFromUp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
agent.upErr = errors.New("up failed")
|
||||
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.upErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should still be reset even when Up fails")
|
||||
}
|
||||
|
||||
func TestHandleWakeUp_SecondCallIsNoOp(t *testing.T) {
|
||||
h, agent := newHandler(internal.StatusIdle)
|
||||
h.sleepTriggeredDown = true
|
||||
|
||||
_ = h.HandleWakeUp(context.Background())
|
||||
err := h.HandleWakeUp(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, 1, agent.upCalls, "second wakeup should be no-op")
|
||||
}
|
||||
|
||||
func TestHandleSleep_SkipsForNonActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Idle", internal.StatusIdle},
|
||||
{"NeedsLogin", internal.StatusNeedsLogin},
|
||||
{"LoginFailed", internal.StatusLoginFailed},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ProceedsForActiveStates(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
status internal.StatusType
|
||||
}{
|
||||
{"Connecting", internal.StatusConnecting},
|
||||
{"Connected", internal.StatusConnected},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
h, _ := newHandler(tt.status)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
require.NoError(t, err)
|
||||
assert.True(t, h.sleepTriggeredDown)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromStatus(t *testing.T) {
|
||||
agent := &mockAgent{statusErr: errors.New("status error")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.statusErr)
|
||||
assert.False(t, h.sleepTriggeredDown)
|
||||
}
|
||||
|
||||
func TestHandleSleep_ReturnsErrorFromDown(t *testing.T) {
|
||||
agent := &mockAgent{status: internal.StatusConnected, downErr: errors.New("down failed")}
|
||||
h := New(agent)
|
||||
|
||||
err := h.HandleSleep(context.Background())
|
||||
|
||||
assert.ErrorIs(t, err, agent.downErr)
|
||||
assert.False(t, h.sleepTriggeredDown, "flag should not be set when Down fails")
|
||||
}
|
||||
@@ -1,214 +0,0 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
type versionUpdateMock struct {
|
||||
latestVersion *v.Version
|
||||
onUpdate func()
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StopWatch() {}
|
||||
|
||||
func (v versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (v *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||
v.onUpdate = updateFn
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) LatestVersion() *v.Version {
|
||||
return v.latestVersion
|
||||
}
|
||||
|
||||
func (v versionUpdateMock) StartFetcher() {}
|
||||
|
||||
func Test_LatestVersion(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should only trigger update once due to time between triggers being < 5 Minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: false,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||
m.update = mockUpdate
|
||||
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("latest")
|
||||
var triggeredInit bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredInit = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredInit = false
|
||||
}
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
var triggeredLater bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
}
|
||||
triggeredLater = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
triggeredLater = false
|
||||
}
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUpdate(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
latestVersion *v.Version
|
||||
expectedVersion string
|
||||
shouldUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "Update to a specific version should update regardless of if latestVersion is available yet",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.56.0",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if version matches",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.55.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to specific version should not update if current version is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.54.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should update if latest is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Update to latest version should not update if latest == current",
|
||||
daemonVersion: "0.56.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if daemon version is invalid",
|
||||
daemonVersion: "development",
|
||||
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expecting latest and latest version is unavailable",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expected version is invalid",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "development",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
}
|
||||
for idx, c := range testMatrix {
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
m, _ := newManager(peer.NewRecorder(""), statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
|
||||
targetVersionChan := make(chan string, 1)
|
||||
|
||||
m.triggerUpdateFn = func(ctx context.Context, targetVersion string) error {
|
||||
targetVersionChan <- targetVersion
|
||||
return nil
|
||||
}
|
||||
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetVersion(c.expectedVersion)
|
||||
|
||||
var updateTriggered bool
|
||||
select {
|
||||
case targetVersion := <-targetVersionChan:
|
||||
if c.expectedVersion == "latest" && targetVersion != c.latestVersion.String() {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), targetVersion)
|
||||
} else if c.expectedVersion != "latest" && targetVersion != c.expectedVersion {
|
||||
t.Errorf("%s: Update version mismatch, expected %v, got %v", c.name, c.expectedVersion, targetVersion)
|
||||
}
|
||||
updateTriggered = true
|
||||
case <-time.After(10 * time.Millisecond):
|
||||
updateTriggered = false
|
||||
}
|
||||
|
||||
if updateTriggered != c.shouldUpdate {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||
}
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
@@ -1,39 +0,0 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package updatemanager
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// Manager is a no-op stub for unsupported platforms
|
||||
type Manager struct{}
|
||||
|
||||
// NewManager returns a no-op manager for unsupported platforms
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
return nil, fmt.Errorf("update manager is not supported on this platform")
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess is a no-op on unsupported platforms
|
||||
func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// Start is a no-op on unsupported platforms
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// SetVersion is a no-op on unsupported platforms
|
||||
func (m *Manager) SetVersion(expectedVersion string) {
|
||||
// no-op
|
||||
}
|
||||
|
||||
// Stop is a no-op on unsupported platforms
|
||||
func (m *Manager) Stop() {
|
||||
// no-op
|
||||
}
|
||||
@@ -1,4 +1,4 @@
|
||||
// Package updatemanager provides automatic update management for the NetBird client.
|
||||
// Package updater provides automatic update management for the NetBird client.
|
||||
// It monitors for new versions, handles update triggers from management server directives,
|
||||
// and orchestrates the download and installation of client updates.
|
||||
//
|
||||
@@ -32,4 +32,4 @@
|
||||
//
|
||||
// This enables verification of successful updates and appropriate user notification
|
||||
// after the client restarts with the new version.
|
||||
package updatemanager
|
||||
package updater
|
||||
@@ -16,8 +16,8 @@ import (
|
||||
goversion "github.com/hashicorp/go-version"
|
||||
log "github.com/sirupsen/logrus"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/downloader"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/downloader"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||
)
|
||||
|
||||
type Installer struct {
|
||||
@@ -203,7 +203,10 @@ func (rh *ResultHandler) write(result Result) error {
|
||||
|
||||
func (rh *ResultHandler) cleanup() error {
|
||||
err := os.Remove(rh.resultFile)
|
||||
if err != nil && !os.IsNotExist(err) {
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
log.Debugf("delete installer result file: %s", rh.resultFile)
|
||||
@@ -1,12 +1,9 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updatemanager
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
@@ -15,7 +12,7 @@ import (
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
||||
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
"github.com/netbirdio/netbird/version"
|
||||
)
|
||||
@@ -41,6 +38,9 @@ type Manager struct {
|
||||
statusRecorder *peer.Status
|
||||
stateManager *statemanager.Manager
|
||||
|
||||
downloadOnly bool // true when no enforcement from management; notifies UI to download latest
|
||||
forceUpdate bool // true when management sets AlwaysUpdate; skips UI interaction and installs directly
|
||||
|
||||
lastTrigger time.Time
|
||||
mgmUpdateChan chan struct{}
|
||||
updateChannel chan struct{}
|
||||
@@ -53,37 +53,38 @@ type Manager struct {
|
||||
expectedVersion *v.Version
|
||||
updateToLatestVersion bool
|
||||
|
||||
// updateMutex protect update and expectedVersion fields
|
||||
pendingVersion *v.Version
|
||||
|
||||
// updateMutex protects update, expectedVersion, updateToLatestVersion,
|
||||
// downloadOnly, forceUpdate, pendingVersion, and lastTrigger fields
|
||||
updateMutex sync.Mutex
|
||||
|
||||
triggerUpdateFn func(context.Context, string) error
|
||||
// installMutex and installing guard against concurrent installation attempts
|
||||
installMutex sync.Mutex
|
||||
installing bool
|
||||
|
||||
// protect to start the service multiple times
|
||||
mu sync.Mutex
|
||||
|
||||
autoUpdateSupported func() bool
|
||||
}
|
||||
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
if runtime.GOOS == "darwin" {
|
||||
isBrew := !installer.TypeOfInstaller(context.Background()).Downloadable()
|
||||
if isBrew {
|
||||
log.Warnf("auto-update disabled on Home Brew installation")
|
||||
return nil, fmt.Errorf("auto-update not supported on Home Brew installation yet")
|
||||
}
|
||||
}
|
||||
return newManager(statusRecorder, stateManager)
|
||||
}
|
||||
|
||||
func newManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) (*Manager, error) {
|
||||
// NewManager creates a new update manager. The manager is single-use: once Stop() is called, it cannot be restarted.
|
||||
func NewManager(statusRecorder *peer.Status, stateManager *statemanager.Manager) *Manager {
|
||||
manager := &Manager{
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
mgmUpdateChan: make(chan struct{}, 1),
|
||||
updateChannel: make(chan struct{}, 1),
|
||||
currentVersion: version.NetbirdVersion(),
|
||||
update: version.NewUpdate("nb/client"),
|
||||
statusRecorder: statusRecorder,
|
||||
stateManager: stateManager,
|
||||
mgmUpdateChan: make(chan struct{}, 1),
|
||||
updateChannel: make(chan struct{}, 1),
|
||||
currentVersion: version.NetbirdVersion(),
|
||||
update: version.NewUpdate("nb/client"),
|
||||
downloadOnly: true,
|
||||
autoUpdateSupported: isAutoUpdateSupported,
|
||||
}
|
||||
manager.triggerUpdateFn = manager.triggerUpdate
|
||||
|
||||
stateManager.RegisterState(&UpdateState{})
|
||||
|
||||
return manager, nil
|
||||
return manager
|
||||
}
|
||||
|
||||
// CheckUpdateSuccess checks if the update was successful and send a notification.
|
||||
@@ -124,8 +125,10 @@ func (m *Manager) CheckUpdateSuccess(ctx context.Context) {
|
||||
}
|
||||
|
||||
func (m *Manager) Start(ctx context.Context) {
|
||||
log.Infof("starting update manager")
|
||||
m.mu.Lock()
|
||||
defer m.mu.Unlock()
|
||||
if m.cancel != nil {
|
||||
log.Errorf("Manager already started")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -142,13 +145,32 @@ func (m *Manager) Start(ctx context.Context) {
|
||||
m.cancel = cancel
|
||||
|
||||
m.wg.Add(1)
|
||||
go m.updateLoop(ctx)
|
||||
go func() {
|
||||
defer m.wg.Done()
|
||||
m.updateLoop(ctx)
|
||||
}()
|
||||
}
|
||||
|
||||
func (m *Manager) SetVersion(expectedVersion string) {
|
||||
log.Infof("set expected agent version for upgrade: %s", expectedVersion)
|
||||
if m.cancel == nil {
|
||||
log.Errorf("manager not started")
|
||||
func (m *Manager) SetDownloadOnly() {
|
||||
m.updateMutex.Lock()
|
||||
m.downloadOnly = true
|
||||
m.forceUpdate = false
|
||||
m.expectedVersion = nil
|
||||
m.updateToLatestVersion = false
|
||||
m.lastTrigger = time.Time{}
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
select {
|
||||
case m.mgmUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
func (m *Manager) SetVersion(expectedVersion string, forceUpdate bool) {
|
||||
log.Infof("expected version changed to %s, force update: %t", expectedVersion, forceUpdate)
|
||||
|
||||
if !m.autoUpdateSupported() {
|
||||
log.Warnf("auto-update not supported on this platform")
|
||||
return
|
||||
}
|
||||
|
||||
@@ -159,6 +181,7 @@ func (m *Manager) SetVersion(expectedVersion string) {
|
||||
log.Errorf("empty expected version provided")
|
||||
m.expectedVersion = nil
|
||||
m.updateToLatestVersion = false
|
||||
m.downloadOnly = true
|
||||
return
|
||||
}
|
||||
|
||||
@@ -178,12 +201,97 @@ func (m *Manager) SetVersion(expectedVersion string) {
|
||||
m.updateToLatestVersion = false
|
||||
}
|
||||
|
||||
m.lastTrigger = time.Time{}
|
||||
m.downloadOnly = false
|
||||
m.forceUpdate = forceUpdate
|
||||
|
||||
select {
|
||||
case m.mgmUpdateChan <- struct{}{}:
|
||||
default:
|
||||
}
|
||||
}
|
||||
|
||||
// Install triggers the installation of the pending version. It is called when the user clicks the install button in the UI.
|
||||
func (m *Manager) Install(ctx context.Context) error {
|
||||
if !m.autoUpdateSupported() {
|
||||
return fmt.Errorf("auto-update not supported on this platform")
|
||||
}
|
||||
|
||||
m.updateMutex.Lock()
|
||||
pending := m.pendingVersion
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
if pending == nil {
|
||||
return fmt.Errorf("no pending version to install")
|
||||
}
|
||||
|
||||
return m.tryInstall(ctx, pending)
|
||||
}
|
||||
|
||||
// tryInstall ensures only one installation runs at a time. Concurrent callers
|
||||
// receive an error immediately rather than queuing behind a running install.
|
||||
func (m *Manager) tryInstall(ctx context.Context, targetVersion *v.Version) error {
|
||||
m.installMutex.Lock()
|
||||
if m.installing {
|
||||
m.installMutex.Unlock()
|
||||
return fmt.Errorf("installation already in progress")
|
||||
}
|
||||
m.installing = true
|
||||
m.installMutex.Unlock()
|
||||
|
||||
defer func() {
|
||||
m.installMutex.Lock()
|
||||
m.installing = false
|
||||
m.installMutex.Unlock()
|
||||
}()
|
||||
|
||||
return m.install(ctx, targetVersion)
|
||||
}
|
||||
|
||||
// NotifyUI re-publishes the current update state to a newly connected UI client.
|
||||
// Only needed for download-only mode where the latest version is already cached
|
||||
// NotifyUI re-publishes the current update state so a newly connected UI gets the info.
|
||||
func (m *Manager) NotifyUI() {
|
||||
m.updateMutex.Lock()
|
||||
if m.update == nil {
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
downloadOnly := m.downloadOnly
|
||||
pendingVersion := m.pendingVersion
|
||||
latestVersion := m.update.LatestVersion()
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
if downloadOnly {
|
||||
if latestVersion == nil {
|
||||
return
|
||||
}
|
||||
currentVersion, err := v.NewVersion(m.currentVersion)
|
||||
if err != nil || currentVersion.GreaterThanOrEqual(latestVersion) {
|
||||
return
|
||||
}
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": latestVersion.String()},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if pendingVersion != nil {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": pendingVersion.String(), "enforced": "true"},
|
||||
)
|
||||
}
|
||||
}
|
||||
|
||||
// Stop is not used at the moment because it fully depends on the daemon. In a future refactor it may make sense to use it.
|
||||
func (m *Manager) Stop() {
|
||||
if m.cancel == nil {
|
||||
return
|
||||
@@ -214,8 +322,6 @@ func (m *Manager) onContextCancel() {
|
||||
}
|
||||
|
||||
func (m *Manager) updateLoop(ctx context.Context) {
|
||||
defer m.wg.Done()
|
||||
|
||||
for {
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
@@ -239,55 +345,89 @@ func (m *Manager) handleUpdate(ctx context.Context) {
|
||||
return
|
||||
}
|
||||
|
||||
expectedVersion := m.expectedVersion
|
||||
useLatest := m.updateToLatestVersion
|
||||
downloadOnly := m.downloadOnly
|
||||
forceUpdate := m.forceUpdate
|
||||
curLatestVersion := m.update.LatestVersion()
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
switch {
|
||||
// Resolve "latest" to actual version
|
||||
case useLatest:
|
||||
// Download-only mode or resolve "latest" to actual version
|
||||
case downloadOnly, m.updateToLatestVersion:
|
||||
if curLatestVersion == nil {
|
||||
log.Tracef("latest version not fetched yet")
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
updateVersion = curLatestVersion
|
||||
// Update to specific version
|
||||
case expectedVersion != nil:
|
||||
updateVersion = expectedVersion
|
||||
// Install to specific version
|
||||
case m.expectedVersion != nil:
|
||||
updateVersion = m.expectedVersion
|
||||
default:
|
||||
log.Debugf("no expected version information set")
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
log.Debugf("checking update option, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||
if !m.shouldUpdate(updateVersion) {
|
||||
if !m.shouldUpdate(updateVersion, forceUpdate) {
|
||||
m.updateMutex.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
m.lastTrigger = time.Now()
|
||||
log.Infof("Auto-update triggered, current version: %s, target version: %s", m.currentVersion, updateVersion)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Automatically updating client",
|
||||
"Your client version is older than auto-update version set in Management, updating client now.",
|
||||
nil,
|
||||
)
|
||||
log.Infof("new version available: %s", updateVersion)
|
||||
|
||||
if !downloadOnly && !forceUpdate {
|
||||
m.pendingVersion = updateVersion
|
||||
}
|
||||
m.updateMutex.Unlock()
|
||||
|
||||
if downloadOnly {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": updateVersion.String()},
|
||||
)
|
||||
return
|
||||
}
|
||||
|
||||
if forceUpdate {
|
||||
if err := m.tryInstall(ctx, updateVersion); err != nil {
|
||||
log.Errorf("force update failed: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_INFO,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"New version available",
|
||||
"",
|
||||
map[string]string{"new_version_available": updateVersion.String(), "enforced": "true"},
|
||||
)
|
||||
}
|
||||
|
||||
func (m *Manager) install(ctx context.Context, pendingVersion *v.Version) error {
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"Updating client",
|
||||
"Installing update now.",
|
||||
nil,
|
||||
)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_CRITICAL,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
"",
|
||||
"",
|
||||
map[string]string{"progress_window": "show", "version": updateVersion.String()},
|
||||
map[string]string{"progress_window": "show", "version": pendingVersion.String()},
|
||||
)
|
||||
|
||||
updateState := UpdateState{
|
||||
PreUpdateVersion: m.currentVersion,
|
||||
TargetVersion: updateVersion.String(),
|
||||
TargetVersion: pendingVersion.String(),
|
||||
}
|
||||
|
||||
if err := m.stateManager.UpdateState(updateState); err != nil {
|
||||
log.Warnf("failed to update state: %v", err)
|
||||
} else {
|
||||
@@ -296,8 +436,9 @@ func (m *Manager) handleUpdate(ctx context.Context) {
|
||||
}
|
||||
}
|
||||
|
||||
if err := m.triggerUpdateFn(ctx, updateVersion.String()); err != nil {
|
||||
log.Errorf("Error triggering auto-update: %v", err)
|
||||
inst := installer.New()
|
||||
if err := inst.RunInstallation(ctx, pendingVersion.String()); err != nil {
|
||||
log.Errorf("error triggering update: %v", err)
|
||||
m.statusRecorder.PublishEvent(
|
||||
cProto.SystemEvent_ERROR,
|
||||
cProto.SystemEvent_SYSTEM,
|
||||
@@ -305,7 +446,9 @@ func (m *Manager) handleUpdate(ctx context.Context) {
|
||||
fmt.Sprintf("Auto-update failed: %v", err),
|
||||
nil,
|
||||
)
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// loadAndDeleteUpdateState loads the update state, deletes it from storage, and returns it.
|
||||
@@ -339,7 +482,7 @@ func (m *Manager) loadAndDeleteUpdateState(ctx context.Context) (*UpdateState, e
|
||||
return updateState, nil
|
||||
}
|
||||
|
||||
func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
|
||||
func (m *Manager) shouldUpdate(updateVersion *v.Version, forceUpdate bool) bool {
|
||||
if m.currentVersion == developmentVersion {
|
||||
log.Debugf("skipping auto-update, running development version")
|
||||
return false
|
||||
@@ -354,8 +497,8 @@ func (m *Manager) shouldUpdate(updateVersion *v.Version) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
if time.Since(m.lastTrigger) < 5*time.Minute {
|
||||
log.Debugf("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
|
||||
if forceUpdate && time.Since(m.lastTrigger) < 3*time.Minute {
|
||||
log.Infof("skipping auto-update, last update was %s ago", time.Since(m.lastTrigger))
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -367,8 +510,3 @@ func (m *Manager) lastResultErrReason() string {
|
||||
result := installer.NewResultHandler(inst.TempDir())
|
||||
return result.GetErrorResultReason()
|
||||
}
|
||||
|
||||
func (m *Manager) triggerUpdate(ctx context.Context, targetVersion string) error {
|
||||
inst := installer.New()
|
||||
return inst.RunInstallation(ctx, targetVersion)
|
||||
}
|
||||
111
client/internal/updater/manager_linux_test.go
Normal file
111
client/internal/updater/manager_linux_test.go
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build !windows && !darwin
|
||||
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
)
|
||||
|
||||
// On Linux, only Mode 1 (downloadOnly) is supported.
|
||||
// SetVersion is a no-op because auto-update installation is not supported.
|
||||
|
||||
func Test_LatestVersion_Linux(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should notify again when a newer version arrives even within 5 minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't notify initially, but should notify as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = mockUpdate
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.Start(context.Background())
|
||||
m.SetDownloadOnly()
|
||||
|
||||
ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredInit := ver != ""
|
||||
if enforced {
|
||||
t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name)
|
||||
}
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial notify mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
ver, enforced = waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredLater := ver != ""
|
||||
if enforced {
|
||||
t.Errorf("%s: Linux Mode 1 must never have enforced metadata", c.name)
|
||||
}
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Later notify mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() {
|
||||
t.Errorf("%s: Later version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_SetVersion_NoOp_Linux(t *testing.T) {
|
||||
// On Linux, SetVersion should be a no-op — no events fired
|
||||
tmpFile := path.Join(t.TempDir(), "update-test-noop.json")
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))}
|
||||
m.currentVersion = "1.0.0"
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("1.0.1", false)
|
||||
|
||||
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
if ver != "" {
|
||||
t.Errorf("SetVersion should be a no-op on Linux, but got event with version %s", ver)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
227
client/internal/updater/manager_test.go
Normal file
227
client/internal/updater/manager_test.go
Normal file
@@ -0,0 +1,227 @@
|
||||
//go:build windows || darwin
|
||||
|
||||
package updater
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"path"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||
cProto "github.com/netbirdio/netbird/client/proto"
|
||||
)
|
||||
|
||||
func Test_LatestVersion(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
initialLatestVersion *v.Version
|
||||
latestVersion *v.Version
|
||||
shouldUpdateInit bool
|
||||
shouldUpdateLater bool
|
||||
}{
|
||||
{
|
||||
name: "Should notify again when a newer version arrives even within 5 minutes",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
latestVersion: v.Must(v.NewSemver("1.0.2")),
|
||||
shouldUpdateInit: true,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
{
|
||||
name: "Shouldn't update initially, but should update as soon as latest version is fetched",
|
||||
daemonVersion: "1.0.0",
|
||||
initialLatestVersion: nil,
|
||||
latestVersion: v.Must(v.NewSemver("1.0.1")),
|
||||
shouldUpdateInit: false,
|
||||
shouldUpdateLater: true,
|
||||
},
|
||||
}
|
||||
|
||||
for idx, c := range testMatrix {
|
||||
mockUpdate := &versionUpdateMock{latestVersion: c.initialLatestVersion}
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = mockUpdate
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.autoUpdateSupported = func() bool { return true }
|
||||
m.Start(context.Background())
|
||||
m.SetVersion("latest", false)
|
||||
|
||||
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredInit := ver != ""
|
||||
if triggeredInit != c.shouldUpdateInit {
|
||||
t.Errorf("%s: Initial update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateInit, triggeredInit)
|
||||
}
|
||||
if triggeredInit && c.initialLatestVersion != nil && ver != c.initialLatestVersion.String() {
|
||||
t.Errorf("%s: Initial update version mismatch, expected %v, got %v", c.name, c.initialLatestVersion.String(), ver)
|
||||
}
|
||||
|
||||
mockUpdate.latestVersion = c.latestVersion
|
||||
mockUpdate.onUpdate()
|
||||
|
||||
ver, _ = waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
triggeredLater := ver != ""
|
||||
if triggeredLater != c.shouldUpdateLater {
|
||||
t.Errorf("%s: Later update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdateLater, triggeredLater)
|
||||
}
|
||||
if triggeredLater && c.latestVersion != nil && ver != c.latestVersion.String() {
|
||||
t.Errorf("%s: Later update version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
|
||||
}
|
||||
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_HandleUpdate(t *testing.T) {
|
||||
testMatrix := []struct {
|
||||
name string
|
||||
daemonVersion string
|
||||
latestVersion *v.Version
|
||||
expectedVersion string
|
||||
shouldUpdate bool
|
||||
}{
|
||||
{
|
||||
name: "Install to a specific version should update regardless of if latestVersion is available yet",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.56.0",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Install to specific version should not update if version matches",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.55.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Install to specific version should not update if current version is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "0.54.0",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Install to latest version should update if latest is newer",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: true,
|
||||
},
|
||||
{
|
||||
name: "Install to latest version should not update if latest == current",
|
||||
daemonVersion: "0.56.0",
|
||||
latestVersion: v.Must(v.NewSemver("0.56.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if daemon version is invalid",
|
||||
daemonVersion: "development",
|
||||
latestVersion: v.Must(v.NewSemver("1.0.0")),
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expecting latest and latest version is unavailable",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "latest",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
{
|
||||
name: "Should not update if expected version is invalid",
|
||||
daemonVersion: "0.55.0",
|
||||
latestVersion: nil,
|
||||
expectedVersion: "development",
|
||||
shouldUpdate: false,
|
||||
},
|
||||
}
|
||||
for idx, c := range testMatrix {
|
||||
tmpFile := path.Join(t.TempDir(), fmt.Sprintf("update-test-%d.json", idx))
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: c.latestVersion}
|
||||
m.currentVersion = c.daemonVersion
|
||||
m.autoUpdateSupported = func() bool { return true }
|
||||
m.Start(context.Background())
|
||||
m.SetVersion(c.expectedVersion, false)
|
||||
|
||||
ver, _ := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
updateTriggered := ver != ""
|
||||
|
||||
if updateTriggered {
|
||||
if c.expectedVersion == "latest" && c.latestVersion != nil && ver != c.latestVersion.String() {
|
||||
t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.latestVersion.String(), ver)
|
||||
} else if c.expectedVersion != "latest" && c.expectedVersion != "development" && ver != c.expectedVersion {
|
||||
t.Errorf("%s: Version mismatch, expected %v, got %v", c.name, c.expectedVersion, ver)
|
||||
}
|
||||
}
|
||||
|
||||
if updateTriggered != c.shouldUpdate {
|
||||
t.Errorf("%s: Update trigger mismatch, expected %v, got %v", c.name, c.shouldUpdate, updateTriggered)
|
||||
}
|
||||
m.Stop()
|
||||
}
|
||||
}
|
||||
|
||||
func Test_EnforcedMetadata(t *testing.T) {
|
||||
// Mode 1 (downloadOnly): no enforced metadata
|
||||
tmpFile := path.Join(t.TempDir(), "update-test-mode1.json")
|
||||
recorder := peer.NewRecorder("")
|
||||
sub := recorder.SubscribeToEvents()
|
||||
defer recorder.UnsubscribeFromEvents(sub)
|
||||
|
||||
m := NewManager(recorder, statemanager.New(tmpFile))
|
||||
m.update = &versionUpdateMock{latestVersion: v.Must(v.NewSemver("1.0.1"))}
|
||||
m.currentVersion = "1.0.0"
|
||||
m.Start(context.Background())
|
||||
m.SetDownloadOnly()
|
||||
|
||||
ver, enforced := waitForUpdateEvent(sub, 500*time.Millisecond)
|
||||
if ver == "" {
|
||||
t.Fatal("Mode 1: expected new_version_available event")
|
||||
}
|
||||
if enforced {
|
||||
t.Error("Mode 1: expected no enforced metadata")
|
||||
}
|
||||
m.Stop()
|
||||
|
||||
// Mode 2 (enforced, forceUpdate=false): enforced metadata present, no auto-install
|
||||
tmpFile2 := path.Join(t.TempDir(), "update-test-mode2.json")
|
||||
recorder2 := peer.NewRecorder("")
|
||||
sub2 := recorder2.SubscribeToEvents()
|
||||
defer recorder2.UnsubscribeFromEvents(sub2)
|
||||
|
||||
m2 := NewManager(recorder2, statemanager.New(tmpFile2))
|
||||
m2.update = &versionUpdateMock{latestVersion: nil}
|
||||
m2.currentVersion = "1.0.0"
|
||||
m2.autoUpdateSupported = func() bool { return true }
|
||||
m2.Start(context.Background())
|
||||
m2.SetVersion("1.0.1", false)
|
||||
|
||||
ver, enforced2 := waitForUpdateEvent(sub2, 500*time.Millisecond)
|
||||
if ver == "" {
|
||||
t.Fatal("Mode 2: expected new_version_available event")
|
||||
}
|
||||
if !enforced2 {
|
||||
t.Error("Mode 2: expected enforced metadata")
|
||||
}
|
||||
m2.Stop()
|
||||
}
|
||||
|
||||
// ensure the proto import is used
|
||||
var _ = cProto.SystemEvent_INFO
|
||||
56
client/internal/updater/manager_test_helpers_test.go
Normal file
56
client/internal/updater/manager_test_helpers_test.go
Normal file
@@ -0,0 +1,56 @@
|
||||
package updater
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"time"
|
||||
|
||||
v "github.com/hashicorp/go-version"
|
||||
|
||||
"github.com/netbirdio/netbird/client/internal/peer"
|
||||
)
|
||||
|
||||
type versionUpdateMock struct {
|
||||
latestVersion *v.Version
|
||||
onUpdate func()
|
||||
}
|
||||
|
||||
func (m versionUpdateMock) StopWatch() {}
|
||||
|
||||
func (m versionUpdateMock) SetDaemonVersion(newVersion string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (m *versionUpdateMock) SetOnUpdateListener(updateFn func()) {
|
||||
m.onUpdate = updateFn
|
||||
}
|
||||
|
||||
func (m versionUpdateMock) LatestVersion() *v.Version {
|
||||
return m.latestVersion
|
||||
}
|
||||
|
||||
func (m versionUpdateMock) StartFetcher() {}
|
||||
|
||||
// waitForUpdateEvent waits for a new_version_available event, returns the version string or "" on timeout.
|
||||
func waitForUpdateEvent(sub *peer.EventSubscription, timeout time.Duration) (version string, enforced bool) {
|
||||
timer := time.NewTimer(timeout)
|
||||
defer timer.Stop()
|
||||
for {
|
||||
select {
|
||||
case event, ok := <-sub.Events():
|
||||
if !ok {
|
||||
return "", false
|
||||
}
|
||||
if val, ok := event.Metadata["new_version_available"]; ok {
|
||||
enforced := false
|
||||
if raw, ok := event.Metadata["enforced"]; ok {
|
||||
if parsed, err := strconv.ParseBool(raw); err == nil {
|
||||
enforced = parsed
|
||||
}
|
||||
}
|
||||
return val, enforced
|
||||
}
|
||||
case <-timer.C:
|
||||
return "", false
|
||||
}
|
||||
}
|
||||
}
|
||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user