mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-27 12:46:39 +00:00
Compare commits
178 Commits
trigger-pr
...
fix/gettin
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
e32ad68f98 | ||
|
|
9d1a37c644 | ||
|
|
5bf2372c4d | ||
|
|
c2c6396a04 | ||
|
|
aaf813fc0c | ||
|
|
d97fe84296 | ||
|
|
81f45dab21 | ||
|
|
d670e7382a | ||
|
|
cd8c686339 | ||
|
|
f5c41e3018 | ||
|
|
2477f99d89 | ||
|
|
940f530ac2 | ||
|
|
4d3e2f8ad3 | ||
|
|
5ae986e1c4 | ||
|
|
e5914e4e8b | ||
|
|
c238f5425f | ||
|
|
3c3097ea74 | ||
|
|
405c3f4003 | ||
|
|
6553ce4cea | ||
|
|
a62d472bc4 | ||
|
|
434ac7f0f5 | ||
|
|
7bbe71c3ac | ||
|
|
04dcaadabf | ||
|
|
c522506849 | ||
|
|
0765352c99 | ||
|
|
13807f1b3d | ||
|
|
c919ea149e | ||
|
|
be6fd119d8 | ||
|
|
7abf730d77 | ||
|
|
ec96c5ecaf | ||
|
|
7e1cce4b9f | ||
|
|
7be8752a00 | ||
|
|
145d82f322 | ||
|
|
a8b9570700 | ||
|
|
6ff6d84646 | ||
|
|
9aaa05e8ea | ||
|
|
0af5a0441f | ||
|
|
0fc63ea0ba | ||
|
|
0b329f7881 | ||
|
|
5b85edb753 | ||
|
|
17cfa5fe1e | ||
|
|
2313494e0e | ||
|
|
fd9d430334 | ||
|
|
91f0d5cefd | ||
|
|
82762280ee | ||
|
|
b550a2face | ||
|
|
ab77508950 | ||
|
|
b9462f5c6b | ||
|
|
5ffaa5cdd6 | ||
|
|
a1858a9cb7 | ||
|
|
212b34f639 | ||
|
|
af8eaa23e2 | ||
|
|
f0eed50678 | ||
|
|
19d94c6158 | ||
|
|
628eb56073 | ||
|
|
a590c38d8b | ||
|
|
4e149c9222 | ||
|
|
59f5b34280 | ||
|
|
dff06d0898 | ||
|
|
80a8816b1d | ||
|
|
387e374e4b | ||
|
|
3e6baea405 | ||
|
|
fe9b844511 | ||
|
|
2e1aa497d2 | ||
|
|
529c0314f8 | ||
|
|
d86875aeac | ||
|
|
f80fe506d5 | ||
|
|
967c6f3cd3 | ||
|
|
e50e124e70 | ||
|
|
c545689448 | ||
|
|
8f389fef19 | ||
|
|
d3d6a327e0 | ||
|
|
b5489d4986 | ||
|
|
7a23c57cf8 | ||
|
|
11f891220e | ||
|
|
5585adce18 | ||
|
|
f884299823 | ||
|
|
15aa6bae1b | ||
|
|
11eb725ac8 | ||
|
|
30c02ab78c | ||
|
|
3acd86e346 | ||
|
|
5c20f13c48 | ||
|
|
e6587b071d | ||
|
|
85451ab4cd | ||
|
|
a7f3ba03eb | ||
|
|
4f0a3a77ad | ||
|
|
44655ca9b5 | ||
|
|
e601278117 | ||
|
|
8e7b016be2 | ||
|
|
9e01ea7aae | ||
|
|
cfc7ec8bb9 | ||
|
|
b3bbc0e5c6 | ||
|
|
d7c8e37ff4 | ||
|
|
05b66e73bc | ||
|
|
01ceedac89 | ||
|
|
403babd433 | ||
|
|
47133031e5 | ||
|
|
82da606886 | ||
|
|
bbe5ae2145 | ||
|
|
0b21498b39 | ||
|
|
0ca59535f1 | ||
|
|
59c77d0658 | ||
|
|
333e045099 | ||
|
|
c2c4d9d336 | ||
|
|
9a6a72e88e | ||
|
|
afe6d9fca4 | ||
|
|
ef82905526 | ||
|
|
d18747e846 | ||
|
|
f341d69314 | ||
|
|
327142837c | ||
|
|
f8c0321aee | ||
|
|
89115ff76a | ||
|
|
63c83aa8d2 | ||
|
|
37f025c966 | ||
|
|
4a54f0d670 | ||
|
|
98890a29e3 | ||
|
|
9d123ec059 | ||
|
|
5d171f181a | ||
|
|
22f878b3b7 | ||
|
|
44ef1a18dd | ||
|
|
2b98dc4e52 | ||
|
|
2a26cb4567 | ||
|
|
5ca1b64328 | ||
|
|
36752a8cbb | ||
|
|
f117fc7509 | ||
|
|
fc6b93ae59 | ||
|
|
564fa4ab04 | ||
|
|
a6db88fbd2 | ||
|
|
4b5294e596 | ||
|
|
a322dce42a | ||
|
|
d1ead2265b | ||
|
|
bbca74476e | ||
|
|
318cf59d66 | ||
|
|
e9b2a6e808 | ||
|
|
2dbdb5c1a7 | ||
|
|
2cdab6d7b7 | ||
|
|
e49c0e8862 | ||
|
|
e7c84d0ead | ||
|
|
1c934cca64 | ||
|
|
4aff4a6424 | ||
|
|
1bd7190954 | ||
|
|
0146e39714 | ||
|
|
baed6e46ec | ||
|
|
0d1ffba75f | ||
|
|
1024d45698 | ||
|
|
e5d4947d60 | ||
|
|
cb9b39b950 | ||
|
|
68c481fa44 | ||
|
|
01a9cd4651 | ||
|
|
f53155562f | ||
|
|
edce11b34d | ||
|
|
841b2d26c6 | ||
|
|
d3eeb6d8ee | ||
|
|
7ebf37ef20 | ||
|
|
64b849c801 | ||
|
|
69d4b5d821 | ||
|
|
3dfa97dcbd | ||
|
|
1ddc9ce2bf | ||
|
|
2de1949018 | ||
|
|
fc88399c23 | ||
|
|
6981fdce7e | ||
|
|
08403f64aa | ||
|
|
391221a986 | ||
|
|
7bc85107eb | ||
|
|
3be16d19a0 | ||
|
|
af8f730bda | ||
|
|
c3f176f348 | ||
|
|
0119f3e9f4 | ||
|
|
1b96648d4d | ||
|
|
d2f9653cea | ||
|
|
194a986926 | ||
|
|
f7732557fa | ||
|
|
d488f58311 | ||
|
|
6fdc00ff41 | ||
|
|
b20d484972 | ||
|
|
8931293343 | ||
|
|
7b830d8f72 | ||
|
|
3a0cf230a1 |
6
.dockerignore
Normal file
6
.dockerignore
Normal file
@@ -0,0 +1,6 @@
|
|||||||
|
.env
|
||||||
|
.env.*
|
||||||
|
*.pem
|
||||||
|
*.key
|
||||||
|
*.crt
|
||||||
|
*.p12
|
||||||
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
14
.github/ISSUE_TEMPLATE/config.yml
vendored
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
blank_issues_enabled: true
|
||||||
|
contact_links:
|
||||||
|
- name: Community Support
|
||||||
|
url: https://forum.netbird.io/
|
||||||
|
about: Community support forum
|
||||||
|
- name: Cloud Support
|
||||||
|
url: https://docs.netbird.io/help/report-bug-issues
|
||||||
|
about: Contact us for support
|
||||||
|
- name: Client/Connection Troubleshooting
|
||||||
|
url: https://docs.netbird.io/help/troubleshooting-client
|
||||||
|
about: See our client troubleshooting guide for help addressing common issues
|
||||||
|
- name: Self-host Troubleshooting
|
||||||
|
url: https://docs.netbird.io/selfhosted/troubleshooting
|
||||||
|
about: See our self-host troubleshooting guide for help addressing common issues
|
||||||
10
.github/workflows/check-license-dependencies.yml
vendored
10
.github/workflows/check-license-dependencies.yml
vendored
@@ -23,7 +23,7 @@ jobs:
|
|||||||
|
|
||||||
- name: Check for problematic license dependencies
|
- name: Check for problematic license dependencies
|
||||||
run: |
|
run: |
|
||||||
echo "Checking for dependencies on management/, signal/, and relay/ packages..."
|
echo "Checking for dependencies on management/, signal/, relay/, and proxy/ packages..."
|
||||||
echo ""
|
echo ""
|
||||||
|
|
||||||
# Find all directories except the problematic ones and system dirs
|
# Find all directories except the problematic ones and system dirs
|
||||||
@@ -31,7 +31,7 @@ jobs:
|
|||||||
while IFS= read -r dir; do
|
while IFS= read -r dir; do
|
||||||
echo "=== Checking $dir ==="
|
echo "=== Checking $dir ==="
|
||||||
# Search for problematic imports, excluding test files
|
# Search for problematic imports, excluding test files
|
||||||
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" || true)
|
RESULTS=$(grep -r "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\)" "$dir" --include="*.go" 2>/dev/null | grep -v "_test.go" | grep -v "test_" | grep -v "/test/" | grep -v "tools/idp-migrate/" || true)
|
||||||
if [ -n "$RESULTS" ]; then
|
if [ -n "$RESULTS" ]; then
|
||||||
echo "❌ Found problematic dependencies:"
|
echo "❌ Found problematic dependencies:"
|
||||||
echo "$RESULTS"
|
echo "$RESULTS"
|
||||||
@@ -39,11 +39,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "✓ No problematic dependencies found"
|
echo "✓ No problematic dependencies found"
|
||||||
fi
|
fi
|
||||||
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name ".git*" | sort)
|
done < <(find . -maxdepth 1 -type d -not -name "." -not -name "management" -not -name "signal" -not -name "relay" -not -name "proxy" -not -name "combined" -not -name ".git*" | sort)
|
||||||
|
|
||||||
echo ""
|
echo ""
|
||||||
if [ $FOUND_ISSUES -eq 1 ]; then
|
if [ $FOUND_ISSUES -eq 1 ]; then
|
||||||
echo "❌ Found dependencies on management/, signal/, or relay/ packages"
|
echo "❌ Found dependencies on management/, signal/, relay/, or proxy/ packages"
|
||||||
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
echo "These packages are licensed under AGPLv3 and must not be imported by BSD-licensed code"
|
||||||
exit 1
|
exit 1
|
||||||
else
|
else
|
||||||
@@ -88,7 +88,7 @@ jobs:
|
|||||||
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
IMPORTERS=$(go list -json -deps ./... 2>/dev/null | jq -r "select(.Imports[]? == \"$package\") | .ImportPath")
|
||||||
|
|
||||||
# Check if any importer is NOT in management/signal/relay
|
# Check if any importer is NOT in management/signal/relay
|
||||||
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\)" | head -1)
|
BSD_IMPORTER=$(echo "$IMPORTERS" | grep -v "github.com/netbirdio/netbird/\(management\|signal\|relay\|proxy\|combined\|tools/idp-migrate\)" | head -1)
|
||||||
|
|
||||||
if [ -n "$BSD_IMPORTER" ]; then
|
if [ -n "$BSD_IMPORTER" ]; then
|
||||||
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
echo "❌ $package ($license) is imported by BSD-licensed code: $BSD_IMPORTER"
|
||||||
|
|||||||
2
.github/workflows/golang-test-darwin.yml
vendored
2
.github/workflows/golang-test-darwin.yml
vendored
@@ -43,5 +43,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
|
|||||||
1
.github/workflows/golang-test-freebsd.yml
vendored
1
.github/workflows/golang-test-freebsd.yml
vendored
@@ -46,6 +46,5 @@ jobs:
|
|||||||
time go test -timeout 1m -failfast ./client/iface/...
|
time go test -timeout 1m -failfast ./client/iface/...
|
||||||
time go test -timeout 1m -failfast ./route/...
|
time go test -timeout 1m -failfast ./route/...
|
||||||
time go test -timeout 1m -failfast ./sharedsock/...
|
time go test -timeout 1m -failfast ./sharedsock/...
|
||||||
time go test -timeout 1m -failfast ./signal/...
|
|
||||||
time go test -timeout 1m -failfast ./util/...
|
time go test -timeout 1m -failfast ./util/...
|
||||||
time go test -timeout 1m -failfast ./version/...
|
time go test -timeout 1m -failfast ./version/...
|
||||||
|
|||||||
98
.github/workflows/golang-test-linux.yml
vendored
98
.github/workflows/golang-test-linux.yml
vendored
@@ -97,6 +97,16 @@ jobs:
|
|||||||
working-directory: relay
|
working-directory: relay
|
||||||
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||||
|
|
||||||
|
- name: Build combined
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build combined 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: combined
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o combined-386 .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
name: "Client / Unit"
|
name: "Client / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -144,7 +154,7 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined)
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
@@ -204,7 +214,7 @@ jobs:
|
|||||||
sh -c ' \
|
sh -c ' \
|
||||||
apk update; apk add --no-cache \
|
apk update; apk add --no-cache \
|
||||||
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /proxy -e /combined -e /client/ui -e /upload-server)
|
||||||
'
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
@@ -261,6 +271,53 @@ jobs:
|
|||||||
-exec 'sudo' \
|
-exec 'sudo' \
|
||||||
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
-timeout 10m -p 1 ./relay/... ./shared/relay/...
|
||||||
|
|
||||||
|
test_proxy:
|
||||||
|
name: "Proxy / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version-file: "go.mod"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test -timeout 10m -p 1 ./proxy/...
|
||||||
|
|
||||||
test_signal:
|
test_signal:
|
||||||
name: "Signal / Unit"
|
name: "Signal / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -352,12 +409,19 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: docker login for root user
|
||||||
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: download mysql image
|
- name: download mysql image
|
||||||
if: matrix.store == 'mysql'
|
if: matrix.store == 'mysql'
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
@@ -440,15 +504,18 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
- name: download mysql image
|
- name: docker login for root user
|
||||||
if: matrix.store == 'mysql'
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
@@ -529,15 +596,18 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Login to Docker hub
|
- name: Login to Docker hub
|
||||||
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
uses: docker/login-action@v1
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
- name: download mysql image
|
- name: docker login for root user
|
||||||
if: matrix.store == 'mysql'
|
if: github.event.pull_request && github.event.pull_request.head.repo && github.event.pull_request.head.repo.full_name == '' || github.repository == github.event.pull_request.head.repo.full_name || !github.head_ref
|
||||||
run: docker pull mlsmaycon/warmed-mysql:8
|
env:
|
||||||
|
DOCKER_USER: ${{ secrets.DOCKER_USER }}
|
||||||
|
DOCKER_TOKEN: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
run: echo "$DOCKER_TOKEN" | sudo docker login --username "$DOCKER_USER" --password-stdin
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
|
|||||||
9
.github/workflows/golang-test-windows.yml
vendored
9
.github/workflows/golang-test-windows.yml
vendored
@@ -63,10 +63,15 @@ jobs:
|
|||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' })" >> $env:GITHUB_ENV
|
- name: Generate test script
|
||||||
|
run: |
|
||||||
|
$packages = go list ./... | Where-Object { $_ -notmatch '/management' } | Where-Object { $_ -notmatch '/relay' } | Where-Object { $_ -notmatch '/signal' } | Where-Object { $_ -notmatch '/proxy' } | Where-Object { $_ -notmatch '/combined' }
|
||||||
|
$goExe = "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe"
|
||||||
|
$cmd = "$goExe test -tags=devcert -timeout 10m -p 1 $($packages -join ' ') > test-out.txt 2>&1"
|
||||||
|
Set-Content -Path "${{ github.workspace }}\run-tests.cmd" -Value $cmd
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "${{ github.workspace }}\run-tests.cmd"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
|||||||
4
.github/workflows/golangci-lint.yml
vendored
4
.github/workflows/golangci-lint.yml
vendored
@@ -19,8 +19,8 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe,cros,ans,deriver,te,userA
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum,**/proxy/web/**
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
51
.github/workflows/pr-title-check.yml
vendored
Normal file
51
.github/workflows/pr-title-check.yml
vendored
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
name: PR Title Check
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
types: [opened, edited, synchronize, reopened]
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
check-title:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Validate PR title prefix
|
||||||
|
uses: actions/github-script@v7
|
||||||
|
with:
|
||||||
|
script: |
|
||||||
|
const title = context.payload.pull_request.title;
|
||||||
|
const allowedTags = [
|
||||||
|
'management',
|
||||||
|
'client',
|
||||||
|
'signal',
|
||||||
|
'proxy',
|
||||||
|
'relay',
|
||||||
|
'misc',
|
||||||
|
'infrastructure',
|
||||||
|
'self-hosted',
|
||||||
|
'doc',
|
||||||
|
];
|
||||||
|
|
||||||
|
const pattern = /^\[([^\]]+)\]\s+.+/;
|
||||||
|
const match = title.match(pattern);
|
||||||
|
|
||||||
|
if (!match) {
|
||||||
|
core.setFailed(
|
||||||
|
`PR title must start with a tag in brackets.\n` +
|
||||||
|
`Example: [client] fix something\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
const tags = match[1].split(',').map(t => t.trim().toLowerCase());
|
||||||
|
|
||||||
|
const invalid = tags.filter(t => !allowedTags.includes(t));
|
||||||
|
if (invalid.length > 0) {
|
||||||
|
core.setFailed(
|
||||||
|
`Invalid tag(s): ${invalid.join(', ')}\n` +
|
||||||
|
`Allowed tags: ${allowedTags.join(', ')}`
|
||||||
|
);
|
||||||
|
return;
|
||||||
|
}
|
||||||
|
|
||||||
|
console.log(`Valid PR title tags: [${tags.join(', ')}]`);
|
||||||
90
.github/workflows/release.yml
vendored
90
.github/workflows/release.yml
vendored
@@ -9,8 +9,8 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.1.0"
|
SIGN_PIPE_VER: "v0.1.1"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.14.3"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|
||||||
@@ -160,7 +160,7 @@ jobs:
|
|||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
- name: Log in to the GitHub container registry
|
- name: Log in to the GitHub container registry
|
||||||
if: github.event_name != 'pull_request'
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
uses: docker/login-action@v3
|
uses: docker/login-action@v3
|
||||||
with:
|
with:
|
||||||
registry: ghcr.io
|
registry: ghcr.io
|
||||||
@@ -169,6 +169,14 @@ jobs:
|
|||||||
- name: Install OS build dependencies
|
- name: Install OS build dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||||
|
|
||||||
|
- name: Decode GPG signing key
|
||||||
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
env:
|
||||||
|
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||||
|
run: |
|
||||||
|
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||||
|
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Install goversioninfo
|
- name: Install goversioninfo
|
||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
@@ -176,6 +184,7 @@ jobs:
|
|||||||
- name: Generate windows syso arm64
|
- name: Generate windows syso arm64
|
||||||
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
run: goversioninfo -arm -64 -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_arm64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
|
id: goreleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
version: ${{ env.GORELEASER_VER }}
|
version: ${{ env.GORELEASER_VER }}
|
||||||
@@ -185,6 +194,55 @@ jobs:
|
|||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
|
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||||
|
NFPM_NETBIRD_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||||
|
- name: Verify RPM signatures
|
||||||
|
run: |
|
||||||
|
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||||
|
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||||
|
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||||
|
rpm --import /tmp/rpm-pub.key
|
||||||
|
echo "=== Verifying RPM signatures ==="
|
||||||
|
for rpm_file in /dist/*amd64*.rpm; do
|
||||||
|
[ -f "$rpm_file" ] || continue
|
||||||
|
echo "--- $(basename $rpm_file) ---"
|
||||||
|
rpm -K "$rpm_file"
|
||||||
|
done
|
||||||
|
'
|
||||||
|
- name: Clean up GPG key
|
||||||
|
if: always()
|
||||||
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
|
- name: Tag and push images (amd64 only)
|
||||||
|
if: |
|
||||||
|
(github.event_name == 'pull_request' && github.event.pull_request.head.repo.full_name == github.repository) ||
|
||||||
|
(github.event_name == 'push' && github.ref == 'refs/heads/main')
|
||||||
|
run: |
|
||||||
|
resolve_tags() {
|
||||||
|
if [[ "${{ github.event_name }}" == "pull_request" ]]; then
|
||||||
|
echo "pr-${{ github.event.pull_request.number }}"
|
||||||
|
else
|
||||||
|
echo "main sha-$(git rev-parse --short HEAD)"
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
tag_and_push() {
|
||||||
|
local src="$1" img_name tag dst
|
||||||
|
img_name="${src%%:*}"
|
||||||
|
for tag in $(resolve_tags); do
|
||||||
|
dst="${img_name}:${tag}"
|
||||||
|
echo "Tagging ${src} -> ${dst}"
|
||||||
|
docker tag "$src" "$dst"
|
||||||
|
docker push "$dst"
|
||||||
|
done
|
||||||
|
}
|
||||||
|
|
||||||
|
export -f tag_and_push resolve_tags
|
||||||
|
|
||||||
|
echo '${{ steps.goreleaser.outputs.artifacts }}' | \
|
||||||
|
jq -r '.[] | select(.type == "Docker Image") | select(.goarch == "amd64") | .name' | \
|
||||||
|
grep '^ghcr.io/' | while read -r SRC; do
|
||||||
|
tag_and_push "$SRC"
|
||||||
|
done
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
@@ -251,6 +309,14 @@ jobs:
|
|||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
run: sudo apt update && sudo apt install -y -q libappindicator3-dev gir1.2-appindicator3-0.1 libxxf86vm-dev gcc-mingw-w64-x86-64
|
||||||
|
|
||||||
|
- name: Decode GPG signing key
|
||||||
|
if: github.event_name != 'pull_request' || github.event.pull_request.head.repo.full_name == github.repository
|
||||||
|
env:
|
||||||
|
GPG_RPM_PRIVATE_KEY: ${{ secrets.GPG_RPM_PRIVATE_KEY }}
|
||||||
|
run: |
|
||||||
|
echo "$GPG_RPM_PRIVATE_KEY" | base64 -d > /tmp/gpg-rpm-signing-key.asc
|
||||||
|
echo "GPG_RPM_KEY_FILE=/tmp/gpg-rpm-signing-key.asc" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
- name: Install LLVM-MinGW for ARM64 cross-compilation
|
||||||
run: |
|
run: |
|
||||||
cd /tmp
|
cd /tmp
|
||||||
@@ -275,6 +341,24 @@ jobs:
|
|||||||
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
HOMEBREW_TAP_GITHUB_TOKEN: ${{ secrets.HOMEBREW_TAP_GITHUB_TOKEN }}
|
||||||
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_DEBIAN_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
UPLOAD_YUM_SECRET: ${{ secrets.PKG_UPLOAD_SECRET }}
|
||||||
|
GPG_RPM_KEY_FILE: ${{ env.GPG_RPM_KEY_FILE }}
|
||||||
|
NFPM_NETBIRD_UI_RPM_PASSPHRASE: ${{ secrets.GPG_RPM_PASSPHRASE }}
|
||||||
|
- name: Verify RPM signatures
|
||||||
|
run: |
|
||||||
|
docker run --rm -v $(pwd)/dist:/dist fedora:41 bash -c '
|
||||||
|
dnf install -y -q rpm-sign curl >/dev/null 2>&1
|
||||||
|
curl -sSL https://pkgs.netbird.io/yum/repodata/repomd.xml.key -o /tmp/rpm-pub.key
|
||||||
|
rpm --import /tmp/rpm-pub.key
|
||||||
|
echo "=== Verifying RPM signatures ==="
|
||||||
|
for rpm_file in /dist/*.rpm; do
|
||||||
|
[ -f "$rpm_file" ] || continue
|
||||||
|
echo "--- $(basename $rpm_file) ---"
|
||||||
|
rpm -K "$rpm_file"
|
||||||
|
done
|
||||||
|
'
|
||||||
|
- name: Clean up GPG key
|
||||||
|
if: always()
|
||||||
|
run: rm -f /tmp/gpg-rpm-signing-key.asc
|
||||||
- name: upload non tags for debug purposes
|
- name: upload non tags for debug purposes
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
|
|||||||
4
.github/workflows/wasm-build-validation.yml
vendored
4
.github/workflows/wasm-build-validation.yml
vendored
@@ -61,8 +61,8 @@ jobs:
|
|||||||
|
|
||||||
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
echo "Size: ${SIZE} bytes (${SIZE_MB} MB)"
|
||||||
|
|
||||||
if [ ${SIZE} -gt 57671680 ]; then
|
if [ ${SIZE} -gt 58720256 ]; then
|
||||||
echo "Wasm binary size (${SIZE_MB}MB) exceeds 55MB limit!"
|
echo "Wasm binary size (${SIZE_MB}MB) exceeds 56MB limit!"
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -2,6 +2,7 @@
|
|||||||
.run
|
.run
|
||||||
*.iml
|
*.iml
|
||||||
dist/
|
dist/
|
||||||
|
!proxy/web/dist/
|
||||||
bin/
|
bin/
|
||||||
.env
|
.env
|
||||||
conf.json
|
conf.json
|
||||||
|
|||||||
220
.goreleaser.yaml
220
.goreleaser.yaml
@@ -106,6 +106,26 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-server
|
||||||
|
dir: combined
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=1
|
||||||
|
- >-
|
||||||
|
{{- if eq .Runtime.Goos "linux" }}
|
||||||
|
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||||
|
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
binary: netbird-server
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
- id: netbird-upload
|
- id: netbird-upload
|
||||||
dir: upload-server
|
dir: upload-server
|
||||||
env: [CGO_ENABLED=0]
|
env: [CGO_ENABLED=0]
|
||||||
@@ -120,6 +140,40 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-proxy
|
||||||
|
dir: proxy/cmd/proxy
|
||||||
|
env: [CGO_ENABLED=0]
|
||||||
|
binary: netbird-proxy
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X main.Version={{.Version}} -X main.Commit={{.Commit}} -X main.BuildDate={{.CommitDate}}
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-idp-migrate
|
||||||
|
dir: tools/idp-migrate
|
||||||
|
env:
|
||||||
|
- CGO_ENABLED=1
|
||||||
|
- >-
|
||||||
|
{{- if eq .Runtime.Goos "linux" }}
|
||||||
|
{{- if eq .Arch "arm64"}}CC=aarch64-linux-gnu-gcc{{- end }}
|
||||||
|
{{- if eq .Arch "arm"}}CC=arm-linux-gnueabihf-gcc{{- end }}
|
||||||
|
{{- end }}
|
||||||
|
binary: netbird-idp-migrate
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -132,18 +186,22 @@ archives:
|
|||||||
- netbird-wasm
|
- netbird-wasm
|
||||||
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
name_template: "{{ .ProjectName }}_{{ .Version }}"
|
||||||
format: binary
|
format: binary
|
||||||
|
- id: netbird-idp-migrate
|
||||||
|
builds:
|
||||||
|
- netbird-idp-migrate
|
||||||
|
name_template: "netbird-idp-migrate_{{ .Version }}_{{ .Os }}_{{ .Arch }}"
|
||||||
|
|
||||||
nfpms:
|
nfpms:
|
||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client.
|
description: Netbird client.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-deb
|
license: BSD-3-Clause
|
||||||
|
id: netbird_deb
|
||||||
bindir: /usr/bin
|
bindir: /usr/bin
|
||||||
builds:
|
builds:
|
||||||
- netbird
|
- netbird
|
||||||
formats:
|
formats:
|
||||||
- deb
|
- deb
|
||||||
|
|
||||||
scripts:
|
scripts:
|
||||||
postinstall: "release_files/post_install.sh"
|
postinstall: "release_files/post_install.sh"
|
||||||
preremove: "release_files/pre_remove.sh"
|
preremove: "release_files/pre_remove.sh"
|
||||||
@@ -151,16 +209,19 @@ nfpms:
|
|||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client.
|
description: Netbird client.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-rpm
|
license: BSD-3-Clause
|
||||||
|
id: netbird_rpm
|
||||||
bindir: /usr/bin
|
bindir: /usr/bin
|
||||||
builds:
|
builds:
|
||||||
- netbird
|
- netbird
|
||||||
formats:
|
formats:
|
||||||
- rpm
|
- rpm
|
||||||
|
|
||||||
scripts:
|
scripts:
|
||||||
postinstall: "release_files/post_install.sh"
|
postinstall: "release_files/post_install.sh"
|
||||||
preremove: "release_files/pre_remove.sh"
|
preremove: "release_files/pre_remove.sh"
|
||||||
|
rpm:
|
||||||
|
signature:
|
||||||
|
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||||
dockers:
|
dockers:
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
@@ -520,6 +581,104 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-server
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: combined/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-proxy
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: proxy/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
@@ -598,6 +757,18 @@ docker_manifests:
|
|||||||
- netbirdio/upload:{{ .Version }}-arm
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
- netbirdio/upload:{{ .Version }}-amd64
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird-server:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird-server:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
@@ -675,6 +846,43 @@ docker_manifests:
|
|||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird-server:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird-server:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird-server:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/reverse-proxy:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/reverse-proxy:{{ .Version }}-amd64
|
||||||
|
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
@@ -695,7 +903,7 @@ brews:
|
|||||||
uploads:
|
uploads:
|
||||||
- name: debian
|
- name: debian
|
||||||
ids:
|
ids:
|
||||||
- netbird-deb
|
- netbird_deb
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
@@ -703,7 +911,7 @@ uploads:
|
|||||||
|
|
||||||
- name: yum
|
- name: yum
|
||||||
ids:
|
ids:
|
||||||
- netbird-rpm
|
- netbird_rpm
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ nfpms:
|
|||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client UI.
|
description: Netbird client UI.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-ui-deb
|
id: netbird_ui_deb
|
||||||
package_name: netbird-ui
|
package_name: netbird-ui
|
||||||
builds:
|
builds:
|
||||||
- netbird-ui
|
- netbird-ui
|
||||||
@@ -80,7 +80,7 @@ nfpms:
|
|||||||
- maintainer: Netbird <dev@netbird.io>
|
- maintainer: Netbird <dev@netbird.io>
|
||||||
description: Netbird client UI.
|
description: Netbird client UI.
|
||||||
homepage: https://netbird.io/
|
homepage: https://netbird.io/
|
||||||
id: netbird-ui-rpm
|
id: netbird_ui_rpm
|
||||||
package_name: netbird-ui
|
package_name: netbird-ui
|
||||||
builds:
|
builds:
|
||||||
- netbird-ui
|
- netbird-ui
|
||||||
@@ -95,11 +95,14 @@ nfpms:
|
|||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
|
rpm:
|
||||||
|
signature:
|
||||||
|
key_file: '{{ if index .Env "GPG_RPM_KEY_FILE" }}{{ .Env.GPG_RPM_KEY_FILE }}{{ end }}'
|
||||||
|
|
||||||
uploads:
|
uploads:
|
||||||
- name: debian
|
- name: debian
|
||||||
ids:
|
ids:
|
||||||
- netbird-ui-deb
|
- netbird_ui_deb
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
target: https://pkgs.wiretrustee.com/debian/pool/{{ .ArtifactName }};deb.distribution=stable;deb.component=main;deb.architecture={{ if .Arm }}armhf{{ else }}{{ .Arch }}{{ end }};deb.package=
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
@@ -107,7 +110,7 @@ uploads:
|
|||||||
|
|
||||||
- name: yum
|
- name: yum
|
||||||
ids:
|
ids:
|
||||||
- netbird-ui-rpm
|
- netbird_ui_rpm
|
||||||
mode: archive
|
mode: archive
|
||||||
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
target: https://pkgs.wiretrustee.com/yum/{{ .Arch }}{{ if .Arm }}{{ .Arm }}{{ end }}
|
||||||
username: dev@wiretrustee.com
|
username: dev@wiretrustee.com
|
||||||
|
|||||||
@@ -1,7 +1,7 @@
|
|||||||
## Contributor License Agreement
|
## Contributor License Agreement
|
||||||
|
|
||||||
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||||
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
submitting this Agreement and NetBird GmbH, Brunnenstraße 196, 10119 Berlin, Germany,
|
||||||
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||||
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||||
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||||
|
|||||||
2
LICENSE
2
LICENSE
@@ -1,4 +1,4 @@
|
|||||||
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/ and relay/.
|
This BSD‑3‑Clause license applies to all parts of the repository except for the directories management/, signal/, relay/ and combined/.
|
||||||
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
Those directories are licensed under the GNU Affero General Public License version 3.0 (AGPLv3). See the respective LICENSE files inside each directory.
|
||||||
|
|
||||||
BSD 3-Clause License
|
BSD 3-Clause License
|
||||||
|
|||||||
@@ -60,8 +60,8 @@
|
|||||||
|
|
||||||
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
https://github.com/user-attachments/assets/10cec749-bb56-4ab3-97af-4e38850108d2
|
||||||
|
|
||||||
### NetBird on Lawrence Systems (Video)
|
### Self-Host NetBird (Video)
|
||||||
[](https://www.youtube.com/watch?v=Kwrff6h0rEw)
|
[](https://youtu.be/bZAgpT6nzaQ)
|
||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
@@ -126,6 +126,7 @@ See a complete [architecture overview](https://docs.netbird.io/about-netbird/how
|
|||||||
### Community projects
|
### Community projects
|
||||||
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
- [NetBird installer script](https://github.com/physk/netbird-installer)
|
||||||
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
- [NetBird ansible collection by Dominion Solutions](https://galaxy.ansible.com/ui/repo/published/dominion_solutions/netbird/)
|
||||||
|
- [netbird-tui](https://github.com/n0pashkov/netbird-tui) — terminal UI for managing NetBird peers, routes, and settings
|
||||||
|
|
||||||
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
**Note**: The `main` branch may be in an *unstable or even broken state* during development.
|
||||||
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
For stable versions, see [releases](https://github.com/netbirdio/netbird/releases).
|
||||||
|
|||||||
@@ -4,7 +4,7 @@
|
|||||||
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
# sudo podman build -t localhost/netbird:latest -f client/Dockerfile --ignorefile .dockerignore-client .
|
||||||
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
# sudo podman run --rm -it --cap-add={BPF,NET_ADMIN,NET_RAW} localhost/netbird:latest
|
||||||
|
|
||||||
FROM alpine:3.23.2
|
FROM alpine:3.23.3
|
||||||
# iproute2: busybox doesn't display ip rules properly
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
RUN apk add --no-cache \
|
RUN apk add --no-cache \
|
||||||
bash \
|
bash \
|
||||||
@@ -17,8 +17,7 @@ ENV \
|
|||||||
NETBIRD_BIN="/usr/local/bin/netbird" \
|
NETBIRD_BIN="/usr/local/bin/netbird" \
|
||||||
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
NB_LOG_FILE="console,/var/log/netbird/client.log" \
|
||||||
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/run/netbird.sock" \
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="5"
|
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|
||||||
|
|||||||
@@ -23,8 +23,7 @@ ENV \
|
|||||||
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
NB_DAEMON_ADDR="unix:///var/lib/netbird/netbird.sock" \
|
||||||
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
NB_LOG_FILE="console,/var/lib/netbird/client.log" \
|
||||||
NB_DISABLE_DNS="true" \
|
NB_DISABLE_DNS="true" \
|
||||||
NB_ENTRYPOINT_SERVICE_TIMEOUT="5" \
|
NB_ENTRYPOINT_SERVICE_TIMEOUT="30"
|
||||||
NB_ENTRYPOINT_LOGIN_TIMEOUT="1"
|
|
||||||
|
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
ENTRYPOINT [ "/usr/local/bin/netbird-entrypoint.sh" ]
|
||||||
|
|
||||||
|
|||||||
@@ -124,7 +124,7 @@ func (c *Client) Run(platformFiles PlatformFiles, urlOpener URLOpener, isAndroid
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -157,7 +157,7 @@ func (c *Client) RunWithoutLogin(platformFiles PlatformFiles, dns *DNSList, dnsR
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder, false)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, slices.Clone(dns.items), dnsReadyListener, stateFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,7 +205,7 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
pi := PeerInfo{
|
pi := PeerInfo{
|
||||||
p.IP,
|
p.IP,
|
||||||
p.FQDN,
|
p.FQDN,
|
||||||
p.ConnStatus.String(),
|
int(p.ConnStatus),
|
||||||
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
PeerRoutes{routes: maps.Keys(p.GetRoutes())},
|
||||||
}
|
}
|
||||||
peerInfos[n] = pi
|
peerInfos[n] = pi
|
||||||
|
|||||||
@@ -1,10 +1,19 @@
|
|||||||
package android
|
package android
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/peer"
|
import (
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
// EnvKeyNBForceRelay Exported for Android java client
|
// EnvKeyNBForceRelay Exported for Android java client to force relay connections
|
||||||
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
EnvKeyNBForceRelay = peer.EnvKeyNBForceRelay
|
||||||
|
|
||||||
|
// EnvKeyNBLazyConn Exported for Android java client to configure lazy connection
|
||||||
|
EnvKeyNBLazyConn = lazyconn.EnvEnableLazyConn
|
||||||
|
|
||||||
|
// EnvKeyNBInactivityThreshold Exported for Android java client to configure connection inactivity threshold
|
||||||
|
EnvKeyNBInactivityThreshold = lazyconn.EnvInactivityThreshold
|
||||||
)
|
)
|
||||||
|
|
||||||
// EnvList wraps a Go map for export to Java
|
// EnvList wraps a Go map for export to Java
|
||||||
|
|||||||
@@ -2,11 +2,20 @@
|
|||||||
|
|
||||||
package android
|
package android
|
||||||
|
|
||||||
|
import "github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
|
||||||
|
// Connection status constants exported via gomobile.
|
||||||
|
const (
|
||||||
|
ConnStatusIdle = int(peer.StatusIdle)
|
||||||
|
ConnStatusConnecting = int(peer.StatusConnecting)
|
||||||
|
ConnStatusConnected = int(peer.StatusConnected)
|
||||||
|
)
|
||||||
|
|
||||||
// PeerInfo describe information about the peers. It designed for the UI usage
|
// PeerInfo describe information about the peers. It designed for the UI usage
|
||||||
type PeerInfo struct {
|
type PeerInfo struct {
|
||||||
IP string
|
IP string
|
||||||
FQDN string
|
FQDN string
|
||||||
ConnStatus string // Todo replace to enum
|
ConnStatus int
|
||||||
Routes PeerRoutes
|
Routes PeerRoutes
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -181,10 +181,11 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||||
|
} else {
|
||||||
|
cmd.Println("netbird up")
|
||||||
|
time.Sleep(time.Second * 10)
|
||||||
}
|
}
|
||||||
cmd.Println("netbird up")
|
|
||||||
time.Sleep(time.Second * 10)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
initialLevelTrace := initialLogLevel.GetLevel() >= proto.LogLevel_TRACE
|
||||||
@@ -199,9 +200,10 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service down: %v\n", status.Convert(err).Message())
|
||||||
|
} else {
|
||||||
|
cmd.Println("netbird down")
|
||||||
}
|
}
|
||||||
cmd.Println("netbird down")
|
|
||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
@@ -209,13 +211,14 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
if _, err := client.SetSyncResponsePersistence(cmd.Context(), &proto.SetSyncResponsePersistenceRequest{
|
||||||
Enabled: true,
|
Enabled: true,
|
||||||
}); err != nil {
|
}); err != nil {
|
||||||
return fmt.Errorf("failed to enable sync response persistence: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to enable sync response persistence: %v\n", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to bring service up: %v\n", status.Convert(err).Message())
|
||||||
|
} else {
|
||||||
|
cmd.Println("netbird up")
|
||||||
}
|
}
|
||||||
cmd.Println("netbird up")
|
|
||||||
|
|
||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
@@ -263,16 +266,18 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to restore service down state: %v\n", status.Convert(err).Message())
|
||||||
|
} else {
|
||||||
|
cmd.Println("netbird down")
|
||||||
}
|
}
|
||||||
cmd.Println("netbird down")
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if !initialLevelTrace {
|
if !initialLevelTrace {
|
||||||
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
if _, err := client.SetLogLevel(cmd.Context(), &proto.SetLogLevelRequest{Level: initialLogLevel.GetLevel()}); err != nil {
|
||||||
return fmt.Errorf("failed to restore log level: %v", status.Convert(err).Message())
|
cmd.PrintErrf("Failed to restore log level: %v\n", status.Convert(err).Message())
|
||||||
|
} else {
|
||||||
|
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||||
}
|
}
|
||||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||||
|
|||||||
286
client/cmd/expose.go
Normal file
286
client/cmd/expose.go
Normal file
@@ -0,0 +1,286 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"regexp"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
var pinRegexp = regexp.MustCompile(`^\d{6}$`)
|
||||||
|
|
||||||
|
var (
|
||||||
|
exposePin string
|
||||||
|
exposePassword string
|
||||||
|
exposeUserGroups []string
|
||||||
|
exposeDomain string
|
||||||
|
exposeNamePrefix string
|
||||||
|
exposeProtocol string
|
||||||
|
exposeExternalPort uint16
|
||||||
|
)
|
||||||
|
|
||||||
|
var exposeCmd = &cobra.Command{
|
||||||
|
Use: "expose <port>",
|
||||||
|
Short: "Expose a local port via the NetBird reverse proxy",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
Example: ` netbird expose --with-password safe-pass 8080
|
||||||
|
netbird expose --protocol tcp 5432
|
||||||
|
netbird expose --protocol tcp --with-external-port 5433 5432
|
||||||
|
netbird expose --protocol tls --with-custom-domain tls.example.com 4443`,
|
||||||
|
RunE: exposeFn,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
exposeCmd.Flags().StringVar(&exposePin, "with-pin", "", "Protect the exposed service with a 6-digit PIN (e.g. --with-pin 123456)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposePassword, "with-password", "", "Protect the exposed service with a password (e.g. --with-password my-secret)")
|
||||||
|
exposeCmd.Flags().StringSliceVar(&exposeUserGroups, "with-user-groups", nil, "Restrict access to specific user groups with SSO (e.g. --with-user-groups devops,Backend)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeDomain, "with-custom-domain", "", "Custom domain for the exposed service, must be configured to your account (e.g. --with-custom-domain myapp.example.com)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeNamePrefix, "with-name-prefix", "", "Prefix for the generated service name (e.g. --with-name-prefix my-app)")
|
||||||
|
exposeCmd.Flags().StringVar(&exposeProtocol, "protocol", "http", "Protocol to use: http, https, tcp, udp, or tls (e.g. --protocol tcp)")
|
||||||
|
exposeCmd.Flags().Uint16Var(&exposeExternalPort, "with-external-port", 0, "Public-facing external port on the proxy cluster (defaults to the target port for L4)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// isClusterProtocol returns true for L4/TLS protocols that reject HTTP-style auth flags.
|
||||||
|
func isClusterProtocol(protocol string) bool {
|
||||||
|
switch strings.ToLower(protocol) {
|
||||||
|
case "tcp", "udp", "tls":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isPortBasedProtocol returns true for pure port-based protocols (TCP/UDP)
|
||||||
|
// where domain display doesn't apply. TLS uses SNI so it has a domain.
|
||||||
|
func isPortBasedProtocol(protocol string) bool {
|
||||||
|
switch strings.ToLower(protocol) {
|
||||||
|
case "tcp", "udp":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractPort returns the port portion of a URL like "tcp://host:12345", or
|
||||||
|
// falls back to the given default formatted as a string.
|
||||||
|
func extractPort(serviceURL string, fallback uint16) string {
|
||||||
|
u := serviceURL
|
||||||
|
if idx := strings.Index(u, "://"); idx != -1 {
|
||||||
|
u = u[idx+3:]
|
||||||
|
}
|
||||||
|
if i := strings.LastIndex(u, ":"); i != -1 {
|
||||||
|
if p := u[i+1:]; p != "" {
|
||||||
|
return p
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return strconv.FormatUint(uint64(fallback), 10)
|
||||||
|
}
|
||||||
|
|
||||||
|
// resolveExternalPort returns the effective external port, defaulting to the target port.
|
||||||
|
func resolveExternalPort(targetPort uint64) uint16 {
|
||||||
|
if exposeExternalPort != 0 {
|
||||||
|
return exposeExternalPort
|
||||||
|
}
|
||||||
|
return uint16(targetPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
func validateExposeFlags(cmd *cobra.Command, portStr string) (uint64, error) {
|
||||||
|
port, err := strconv.ParseUint(portStr, 10, 32)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid port number: %s", portStr)
|
||||||
|
}
|
||||||
|
if port == 0 || port > 65535 {
|
||||||
|
return 0, fmt.Errorf("invalid port number: must be between 1 and 65535")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !isProtocolValid(exposeProtocol) {
|
||||||
|
return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", exposeProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
if isClusterProtocol(exposeProtocol) {
|
||||||
|
if exposePin != "" || exposePassword != "" || len(exposeUserGroups) > 0 {
|
||||||
|
return 0, fmt.Errorf("auth flags (--with-pin, --with-password, --with-user-groups) are not supported for %s protocol", exposeProtocol)
|
||||||
|
}
|
||||||
|
} else if cmd.Flags().Changed("with-external-port") {
|
||||||
|
return 0, fmt.Errorf("--with-external-port is not supported for %s protocol", exposeProtocol)
|
||||||
|
}
|
||||||
|
|
||||||
|
if exposePin != "" && !pinRegexp.MatchString(exposePin) {
|
||||||
|
return 0, fmt.Errorf("invalid pin: must be exactly 6 digits")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flags().Changed("with-password") && exposePassword == "" {
|
||||||
|
return 0, fmt.Errorf("password cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flags().Changed("with-user-groups") && len(exposeUserGroups) == 0 {
|
||||||
|
return 0, fmt.Errorf("user groups cannot be empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
return port, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func isProtocolValid(exposeProtocol string) bool {
|
||||||
|
switch strings.ToLower(exposeProtocol) {
|
||||||
|
case "http", "https", "tcp", "udp", "tls":
|
||||||
|
return true
|
||||||
|
default:
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func exposeFn(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(rootCmd)
|
||||||
|
|
||||||
|
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||||
|
log.Errorf("failed initializing log %v", err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Root().SilenceUsage = false
|
||||||
|
|
||||||
|
port, err := validateExposeFlags(cmd, args[0])
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Root().SilenceUsage = true
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
go func() {
|
||||||
|
<-sigCh
|
||||||
|
cancel()
|
||||||
|
}()
|
||||||
|
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("connect to daemon: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Debugf("failed to close daemon connection: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
|
protocol, err := toExposeProtocol(exposeProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
req := &proto.ExposeServiceRequest{
|
||||||
|
Port: uint32(port),
|
||||||
|
Protocol: protocol,
|
||||||
|
Pin: exposePin,
|
||||||
|
Password: exposePassword,
|
||||||
|
UserGroups: exposeUserGroups,
|
||||||
|
Domain: exposeDomain,
|
||||||
|
NamePrefix: exposeNamePrefix,
|
||||||
|
}
|
||||||
|
if isClusterProtocol(exposeProtocol) {
|
||||||
|
req.ListenPort = uint32(resolveExternalPort(port))
|
||||||
|
}
|
||||||
|
|
||||||
|
stream, err := client.ExposeService(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("expose service: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := handleExposeReady(cmd, stream, port); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
return waitForExposeEvents(cmd, ctx, stream)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toExposeProtocol(exposeProtocol string) (proto.ExposeProtocol, error) {
|
||||||
|
p, err := expose.ParseProtocolType(exposeProtocol)
|
||||||
|
if err != nil {
|
||||||
|
return 0, fmt.Errorf("invalid protocol: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
switch p {
|
||||||
|
case expose.ProtocolHTTP:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_HTTP, nil
|
||||||
|
case expose.ProtocolHTTPS:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_HTTPS, nil
|
||||||
|
case expose.ProtocolTCP:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_TCP, nil
|
||||||
|
case expose.ProtocolUDP:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_UDP, nil
|
||||||
|
case expose.ProtocolTLS:
|
||||||
|
return proto.ExposeProtocol_EXPOSE_TLS, nil
|
||||||
|
default:
|
||||||
|
return 0, fmt.Errorf("unhandled protocol type: %d", p)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func handleExposeReady(cmd *cobra.Command, stream proto.DaemonService_ExposeServiceClient, port uint64) error {
|
||||||
|
event, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("receive expose event: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ready, ok := event.Event.(*proto.ExposeServiceEvent_Ready)
|
||||||
|
if !ok {
|
||||||
|
return fmt.Errorf("unexpected expose event: %T", event.Event)
|
||||||
|
}
|
||||||
|
printExposeReady(cmd, ready.Ready, port)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printExposeReady(cmd *cobra.Command, r *proto.ExposeServiceReady, port uint64) {
|
||||||
|
cmd.Println("Service exposed successfully!")
|
||||||
|
cmd.Printf(" Name: %s\n", r.ServiceName)
|
||||||
|
if r.ServiceUrl != "" {
|
||||||
|
cmd.Printf(" URL: %s\n", r.ServiceUrl)
|
||||||
|
}
|
||||||
|
if r.Domain != "" && !isPortBasedProtocol(exposeProtocol) {
|
||||||
|
cmd.Printf(" Domain: %s\n", r.Domain)
|
||||||
|
}
|
||||||
|
cmd.Printf(" Protocol: %s\n", exposeProtocol)
|
||||||
|
cmd.Printf(" Internal: %d\n", port)
|
||||||
|
if isClusterProtocol(exposeProtocol) {
|
||||||
|
cmd.Printf(" External: %s\n", extractPort(r.ServiceUrl, resolveExternalPort(port)))
|
||||||
|
}
|
||||||
|
if r.PortAutoAssigned && exposeExternalPort != 0 {
|
||||||
|
cmd.Printf("\n Note: requested port %d was reassigned\n", exposeExternalPort)
|
||||||
|
}
|
||||||
|
cmd.Println()
|
||||||
|
cmd.Println("Press Ctrl+C to stop exposing.")
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForExposeEvents(cmd *cobra.Command, ctx context.Context, stream proto.DaemonService_ExposeServiceClient) error {
|
||||||
|
for {
|
||||||
|
_, err := stream.Recv()
|
||||||
|
if err != nil {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
cmd.Println("\nService stopped.")
|
||||||
|
//nolint:nilerr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if errors.Is(err, io.EOF) {
|
||||||
|
return fmt.Errorf("connection to daemon closed unexpectedly")
|
||||||
|
}
|
||||||
|
return fmt.Errorf("stream error: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -282,13 +282,9 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *profileman
|
|||||||
}
|
}
|
||||||
defer authClient.Close()
|
defer authClient.Close()
|
||||||
|
|
||||||
needsLogin := false
|
needsLogin, err := authClient.IsLoginRequired(ctx)
|
||||||
|
if err != nil {
|
||||||
err, isAuthError := authClient.Login(ctx, "", "")
|
return fmt.Errorf("check login required: %v", err)
|
||||||
if isAuthError {
|
|
||||||
needsLogin = true
|
|
||||||
} else if err != nil {
|
|
||||||
return fmt.Errorf("login check failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
jwtToken := ""
|
jwtToken := ""
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"google.golang.org/grpc"
|
"google.golang.org/grpc"
|
||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
|
daddr "github.com/netbirdio/netbird/client/internal/daemonaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -80,6 +81,15 @@ var (
|
|||||||
Short: "",
|
Short: "",
|
||||||
Long: "",
|
Long: "",
|
||||||
SilenceUsage: true,
|
SilenceUsage: true,
|
||||||
|
PersistentPreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
SetFlagsFromEnvVars(cmd.Root())
|
||||||
|
|
||||||
|
// Don't resolve for service commands — they create the socket, not connect to it.
|
||||||
|
if !isServiceCmd(cmd) {
|
||||||
|
daemonAddr = daddr.ResolveUnixDaemonAddr(daemonAddr)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
}
|
}
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -144,6 +154,7 @@ func init() {
|
|||||||
rootCmd.AddCommand(forwardingRulesCmd)
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
rootCmd.AddCommand(profileCmd)
|
rootCmd.AddCommand(profileCmd)
|
||||||
|
rootCmd.AddCommand(exposeCmd)
|
||||||
|
|
||||||
networksCMD.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
@@ -385,7 +396,6 @@ func migrateToNetbird(oldPath, newPath string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
conn, err := DialClientGRPCServer(cmd.Context(), daemonAddr)
|
||||||
@@ -398,3 +408,13 @@ func getClient(cmd *cobra.Command) (*grpc.ClientConn, error) {
|
|||||||
|
|
||||||
return conn, nil
|
return conn, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// isServiceCmd returns true if cmd is the "service" command or a child of it.
|
||||||
|
func isServiceCmd(cmd *cobra.Command) bool {
|
||||||
|
for c := cmd; c != nil; c = c.Parent() {
|
||||||
|
if c.Name() == "service" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -41,7 +41,7 @@ func init() {
|
|||||||
defaultServiceName = "Netbird"
|
defaultServiceName = "Netbird"
|
||||||
}
|
}
|
||||||
|
|
||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd)
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd, svcStatusCmd, installCmd, uninstallCmd, reconfigureCmd, resetParamsCmd)
|
||||||
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
serviceCmd.PersistentFlags().BoolVar(&profilesDisabled, "disable-profiles", false, "Disables profiles feature. If enabled, the client will not be able to change or edit any profile. To persist this setting, use: netbird service install --disable-profiles")
|
||||||
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
serviceCmd.PersistentFlags().BoolVar(&updateSettingsDisabled, "disable-update-settings", false, "Disables update settings feature. If enabled, the client will not be able to change or edit any settings. To persist this setting, use: netbird service install --disable-update-settings")
|
||||||
|
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ func (p *program) Stop(srv service.Service) error {
|
|||||||
|
|
||||||
// Common setup for service control commands
|
// Common setup for service control commands
|
||||||
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
func setupServiceControlCommand(cmd *cobra.Command, ctx context.Context, cancel context.CancelFunc) (service.Service, error) {
|
||||||
SetFlagsFromEnvVars(rootCmd)
|
// rootCmd env vars are already applied by PersistentPreRunE.
|
||||||
SetFlagsFromEnvVars(serviceCmd)
|
SetFlagsFromEnvVars(serviceCmd)
|
||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|||||||
@@ -119,6 +119,10 @@ var installCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
svcConfig, err := createServiceConfigForInstall()
|
svcConfig, err := createServiceConfigForInstall()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -136,6 +140,10 @@ var installCmd = &cobra.Command{
|
|||||||
return fmt.Errorf("install service: %w", err)
|
return fmt.Errorf("install service: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("NetBird service has been installed")
|
cmd.Println("NetBird service has been installed")
|
||||||
return nil
|
return nil
|
||||||
},
|
},
|
||||||
@@ -187,6 +195,10 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := loadAndApplyServiceParams(cmd); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to load saved service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
wasRunning, err := isServiceRunning()
|
wasRunning, err := isServiceRunning()
|
||||||
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
if err != nil && !errors.Is(err, ErrGetServiceStatus) {
|
||||||
return fmt.Errorf("check service status: %w", err)
|
return fmt.Errorf("check service status: %w", err)
|
||||||
@@ -222,6 +234,10 @@ This command will temporarily stop the service, update its configuration, and re
|
|||||||
return fmt.Errorf("install service with new config: %w", err)
|
return fmt.Errorf("install service with new config: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := saveServiceParams(currentServiceParams()); err != nil {
|
||||||
|
cmd.PrintErrf("Warning: failed to save service params: %v\n", err)
|
||||||
|
}
|
||||||
|
|
||||||
if wasRunning {
|
if wasRunning {
|
||||||
cmd.Println("Starting NetBird service...")
|
cmd.Println("Starting NetBird service...")
|
||||||
if err := s.Start(); err != nil {
|
if err := s.Start(); err != nil {
|
||||||
|
|||||||
201
client/cmd/service_params.go
Normal file
201
client/cmd/service_params.go
Normal file
@@ -0,0 +1,201 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"fmt"
|
||||||
|
"maps"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
"github.com/netbirdio/netbird/util"
|
||||||
|
)
|
||||||
|
|
||||||
|
const serviceParamsFile = "service.json"
|
||||||
|
|
||||||
|
// serviceParams holds install-time service parameters that persist across
|
||||||
|
// uninstall/reinstall cycles. Saved to <stateDir>/service.json.
|
||||||
|
type serviceParams struct {
|
||||||
|
LogLevel string `json:"log_level"`
|
||||||
|
DaemonAddr string `json:"daemon_addr"`
|
||||||
|
ManagementURL string `json:"management_url,omitempty"`
|
||||||
|
ConfigPath string `json:"config_path,omitempty"`
|
||||||
|
LogFiles []string `json:"log_files,omitempty"`
|
||||||
|
DisableProfiles bool `json:"disable_profiles,omitempty"`
|
||||||
|
DisableUpdateSettings bool `json:"disable_update_settings,omitempty"`
|
||||||
|
ServiceEnvVars map[string]string `json:"service_env_vars,omitempty"`
|
||||||
|
}
|
||||||
|
|
||||||
|
// serviceParamsPath returns the path to the service params file.
|
||||||
|
func serviceParamsPath() string {
|
||||||
|
return filepath.Join(configs.StateDir, serviceParamsFile)
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadServiceParams reads saved service parameters from disk.
|
||||||
|
// Returns nil with no error if the file does not exist.
|
||||||
|
func loadServiceParams() (*serviceParams, error) {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
|
||||||
|
data, err := os.ReadFile(path)
|
||||||
|
if err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
return nil, nil //nolint:nilnil
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("read service params %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var params serviceParams
|
||||||
|
if err := json.Unmarshal(data, ¶ms); err != nil {
|
||||||
|
return nil, fmt.Errorf("parse service params %s: %w", path, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return ¶ms, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// saveServiceParams writes current service parameters to disk atomically
|
||||||
|
// with restricted permissions.
|
||||||
|
func saveServiceParams(params *serviceParams) error {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
if err := util.WriteJsonWithRestrictedPermission(context.Background(), path, params); err != nil {
|
||||||
|
return fmt.Errorf("save service params: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// currentServiceParams captures the current state of all package-level
|
||||||
|
// variables into a serviceParams struct.
|
||||||
|
func currentServiceParams() *serviceParams {
|
||||||
|
params := &serviceParams{
|
||||||
|
LogLevel: logLevel,
|
||||||
|
DaemonAddr: daemonAddr,
|
||||||
|
ManagementURL: managementURL,
|
||||||
|
ConfigPath: configPath,
|
||||||
|
LogFiles: logFiles,
|
||||||
|
DisableProfiles: profilesDisabled,
|
||||||
|
DisableUpdateSettings: updateSettingsDisabled,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(serviceEnvVars) > 0 {
|
||||||
|
parsed, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
if err == nil && len(parsed) > 0 {
|
||||||
|
params.ServiceEnvVars = parsed
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return params
|
||||||
|
}
|
||||||
|
|
||||||
|
// loadAndApplyServiceParams loads saved params from disk and applies them
|
||||||
|
// to any flags that were not explicitly set.
|
||||||
|
func loadAndApplyServiceParams(cmd *cobra.Command) error {
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
applyServiceParams(cmd, params)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyServiceParams merges saved parameters into package-level variables
|
||||||
|
// for any flag that was not explicitly set by the user (via CLI or env var).
|
||||||
|
// Flags that were Changed() are left untouched.
|
||||||
|
func applyServiceParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
|
if params == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// For fields with non-empty defaults (log-level, daemon-addr), keep the
|
||||||
|
// != "" guard so that an older service.json missing the field doesn't
|
||||||
|
// clobber the default with an empty string.
|
||||||
|
if !rootCmd.PersistentFlags().Changed("log-level") && params.LogLevel != "" {
|
||||||
|
logLevel = params.LogLevel
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("daemon-addr") && params.DaemonAddr != "" {
|
||||||
|
daemonAddr = params.DaemonAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
// For optional fields where empty means "use default", always apply so
|
||||||
|
// that an explicit clear (--management-url "") persists across reinstalls.
|
||||||
|
if !rootCmd.PersistentFlags().Changed("management-url") {
|
||||||
|
managementURL = params.ManagementURL
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("config") {
|
||||||
|
configPath = params.ConfigPath
|
||||||
|
}
|
||||||
|
|
||||||
|
if !rootCmd.PersistentFlags().Changed("log-file") {
|
||||||
|
logFiles = params.LogFiles
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-profiles") {
|
||||||
|
profilesDisabled = params.DisableProfiles
|
||||||
|
}
|
||||||
|
|
||||||
|
if !serviceCmd.PersistentFlags().Changed("disable-update-settings") {
|
||||||
|
updateSettingsDisabled = params.DisableUpdateSettings
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyServiceEnvParams merges saved service environment variables.
|
||||||
|
// If --service-env was explicitly set, explicit values win on key conflict
|
||||||
|
// but saved keys not in the explicit set are carried over.
|
||||||
|
// If --service-env was not set, saved env vars are used entirely.
|
||||||
|
func applyServiceEnvParams(cmd *cobra.Command, params *serviceParams) {
|
||||||
|
if len(params.ServiceEnvVars) == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !cmd.Flags().Changed("service-env") {
|
||||||
|
// No explicit env vars: rebuild serviceEnvVars from saved params.
|
||||||
|
serviceEnvVars = envMapToSlice(params.ServiceEnvVars)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Explicit env vars were provided: merge saved values underneath.
|
||||||
|
explicit, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
if err != nil {
|
||||||
|
cmd.PrintErrf("Warning: parse explicit service env vars for merge: %v\n", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
merged := make(map[string]string, len(params.ServiceEnvVars)+len(explicit))
|
||||||
|
maps.Copy(merged, params.ServiceEnvVars)
|
||||||
|
maps.Copy(merged, explicit) // explicit wins on conflict
|
||||||
|
serviceEnvVars = envMapToSlice(merged)
|
||||||
|
}
|
||||||
|
|
||||||
|
var resetParamsCmd = &cobra.Command{
|
||||||
|
Use: "reset-params",
|
||||||
|
Short: "Remove saved service install parameters",
|
||||||
|
Long: "Removes the saved service.json file so the next install uses default parameters.",
|
||||||
|
RunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
path := serviceParamsPath()
|
||||||
|
if err := os.Remove(path); err != nil {
|
||||||
|
if os.IsNotExist(err) {
|
||||||
|
cmd.Println("No saved service parameters found")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return fmt.Errorf("remove service params: %w", err)
|
||||||
|
}
|
||||||
|
cmd.Printf("Removed saved service parameters (%s)\n", path)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// envMapToSlice converts a map of env vars to a KEY=VALUE slice.
|
||||||
|
func envMapToSlice(m map[string]string) []string {
|
||||||
|
s := make([]string, 0, len(m))
|
||||||
|
for k, v := range m {
|
||||||
|
s = append(s, k+"="+v)
|
||||||
|
}
|
||||||
|
return s
|
||||||
|
}
|
||||||
523
client/cmd/service_params_test.go
Normal file
523
client/cmd/service_params_test.go
Normal file
@@ -0,0 +1,523 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"go/ast"
|
||||||
|
"go/parser"
|
||||||
|
"go/token"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"github.com/spf13/pflag"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/configs"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestServiceParamsPath(t *testing.T) {
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
|
||||||
|
configs.StateDir = "/var/lib/netbird"
|
||||||
|
assert.Equal(t, filepath.Join("/var/lib/netbird", "service.json"), serviceParamsPath())
|
||||||
|
|
||||||
|
configs.StateDir = "/custom/state"
|
||||||
|
assert.Equal(t, filepath.Join("/custom/state", "service.json"), serviceParamsPath())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSaveAndLoadServiceParams(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
params := &serviceParams{
|
||||||
|
LogLevel: "debug",
|
||||||
|
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||||
|
ManagementURL: "https://my.server.com",
|
||||||
|
ConfigPath: "/etc/netbird/config.json",
|
||||||
|
LogFiles: []string{"/var/log/netbird/client.log", "console"},
|
||||||
|
DisableProfiles: true,
|
||||||
|
DisableUpdateSettings: false,
|
||||||
|
ServiceEnvVars: map[string]string{"NB_LOG_FORMAT": "json", "CUSTOM": "val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := saveServiceParams(params)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the file exists and is valid JSON.
|
||||||
|
data, err := os.ReadFile(filepath.Join(tmpDir, "service.json"))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, json.Valid(data))
|
||||||
|
|
||||||
|
loaded, err := loadServiceParams()
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, loaded)
|
||||||
|
|
||||||
|
assert.Equal(t, params.LogLevel, loaded.LogLevel)
|
||||||
|
assert.Equal(t, params.DaemonAddr, loaded.DaemonAddr)
|
||||||
|
assert.Equal(t, params.ManagementURL, loaded.ManagementURL)
|
||||||
|
assert.Equal(t, params.ConfigPath, loaded.ConfigPath)
|
||||||
|
assert.Equal(t, params.LogFiles, loaded.LogFiles)
|
||||||
|
assert.Equal(t, params.DisableProfiles, loaded.DisableProfiles)
|
||||||
|
assert.Equal(t, params.DisableUpdateSettings, loaded.DisableUpdateSettings)
|
||||||
|
assert.Equal(t, params.ServiceEnvVars, loaded.ServiceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadServiceParams_FileNotExists(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
assert.NoError(t, err)
|
||||||
|
assert.Nil(t, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLoadServiceParams_InvalidJSON(t *testing.T) {
|
||||||
|
tmpDir := t.TempDir()
|
||||||
|
|
||||||
|
original := configs.StateDir
|
||||||
|
t.Cleanup(func() { configs.StateDir = original })
|
||||||
|
configs.StateDir = tmpDir
|
||||||
|
|
||||||
|
err := os.WriteFile(filepath.Join(tmpDir, "service.json"), []byte("not json"), 0600)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
params, err := loadServiceParams()
|
||||||
|
assert.Error(t, err)
|
||||||
|
assert.Nil(t, params)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestCurrentServiceParams(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
origDaemonAddr := daemonAddr
|
||||||
|
origManagementURL := managementURL
|
||||||
|
origConfigPath := configPath
|
||||||
|
origLogFiles := logFiles
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logLevel = origLogLevel
|
||||||
|
daemonAddr = origDaemonAddr
|
||||||
|
managementURL = origManagementURL
|
||||||
|
configPath = origConfigPath
|
||||||
|
logFiles = origLogFiles
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
serviceEnvVars = origServiceEnvVars
|
||||||
|
})
|
||||||
|
|
||||||
|
logLevel = "trace"
|
||||||
|
daemonAddr = "tcp://127.0.0.1:9999"
|
||||||
|
managementURL = "https://mgmt.example.com"
|
||||||
|
configPath = "/tmp/test-config.json"
|
||||||
|
logFiles = []string{"/tmp/test.log"}
|
||||||
|
profilesDisabled = true
|
||||||
|
updateSettingsDisabled = true
|
||||||
|
serviceEnvVars = []string{"FOO=bar", "BAZ=qux"}
|
||||||
|
|
||||||
|
params := currentServiceParams()
|
||||||
|
|
||||||
|
assert.Equal(t, "trace", params.LogLevel)
|
||||||
|
assert.Equal(t, "tcp://127.0.0.1:9999", params.DaemonAddr)
|
||||||
|
assert.Equal(t, "https://mgmt.example.com", params.ManagementURL)
|
||||||
|
assert.Equal(t, "/tmp/test-config.json", params.ConfigPath)
|
||||||
|
assert.Equal(t, []string{"/tmp/test.log"}, params.LogFiles)
|
||||||
|
assert.True(t, params.DisableProfiles)
|
||||||
|
assert.True(t, params.DisableUpdateSettings)
|
||||||
|
assert.Equal(t, map[string]string{"FOO": "bar", "BAZ": "qux"}, params.ServiceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_OnlyUnchangedFlags(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
origDaemonAddr := daemonAddr
|
||||||
|
origManagementURL := managementURL
|
||||||
|
origConfigPath := configPath
|
||||||
|
origLogFiles := logFiles
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() {
|
||||||
|
logLevel = origLogLevel
|
||||||
|
daemonAddr = origDaemonAddr
|
||||||
|
managementURL = origManagementURL
|
||||||
|
configPath = origConfigPath
|
||||||
|
logFiles = origLogFiles
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
serviceEnvVars = origServiceEnvVars
|
||||||
|
})
|
||||||
|
|
||||||
|
// Reset all flags to defaults.
|
||||||
|
logLevel = "info"
|
||||||
|
daemonAddr = "unix:///var/run/netbird.sock"
|
||||||
|
managementURL = ""
|
||||||
|
configPath = "/etc/netbird/config.json"
|
||||||
|
logFiles = []string{"/var/log/netbird/client.log"}
|
||||||
|
profilesDisabled = false
|
||||||
|
updateSettingsDisabled = false
|
||||||
|
serviceEnvVars = nil
|
||||||
|
|
||||||
|
// Reset Changed state on all relevant flags.
|
||||||
|
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate user explicitly setting --log-level via CLI.
|
||||||
|
logLevel = "warn"
|
||||||
|
require.NoError(t, rootCmd.PersistentFlags().Set("log-level", "warn"))
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
LogLevel: "debug",
|
||||||
|
DaemonAddr: "tcp://127.0.0.1:5555",
|
||||||
|
ManagementURL: "https://saved.example.com",
|
||||||
|
ConfigPath: "/saved/config.json",
|
||||||
|
LogFiles: []string{"/saved/client.log"},
|
||||||
|
DisableProfiles: true,
|
||||||
|
DisableUpdateSettings: true,
|
||||||
|
ServiceEnvVars: map[string]string{"SAVED_KEY": "saved_val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
// log-level was Changed, so it should keep "warn", not use saved "debug".
|
||||||
|
assert.Equal(t, "warn", logLevel)
|
||||||
|
|
||||||
|
// All other fields were not Changed, so they should use saved values.
|
||||||
|
assert.Equal(t, "tcp://127.0.0.1:5555", daemonAddr)
|
||||||
|
assert.Equal(t, "https://saved.example.com", managementURL)
|
||||||
|
assert.Equal(t, "/saved/config.json", configPath)
|
||||||
|
assert.Equal(t, []string{"/saved/client.log"}, logFiles)
|
||||||
|
assert.True(t, profilesDisabled)
|
||||||
|
assert.True(t, updateSettingsDisabled)
|
||||||
|
assert.Equal(t, []string{"SAVED_KEY=saved_val"}, serviceEnvVars)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_BooleanRevertToFalse(t *testing.T) {
|
||||||
|
origProfilesDisabled := profilesDisabled
|
||||||
|
origUpdateSettingsDisabled := updateSettingsDisabled
|
||||||
|
t.Cleanup(func() {
|
||||||
|
profilesDisabled = origProfilesDisabled
|
||||||
|
updateSettingsDisabled = origUpdateSettingsDisabled
|
||||||
|
})
|
||||||
|
|
||||||
|
// Simulate current state where booleans are true (e.g. set by previous install).
|
||||||
|
profilesDisabled = true
|
||||||
|
updateSettingsDisabled = true
|
||||||
|
|
||||||
|
// Reset Changed state so flags appear unset.
|
||||||
|
serviceCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
// Saved params have both as false.
|
||||||
|
saved := &serviceParams{
|
||||||
|
DisableProfiles: false,
|
||||||
|
DisableUpdateSettings: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.False(t, profilesDisabled, "saved false should override current true")
|
||||||
|
assert.False(t, updateSettingsDisabled, "saved false should override current true")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_ClearManagementURL(t *testing.T) {
|
||||||
|
origManagementURL := managementURL
|
||||||
|
t.Cleanup(func() { managementURL = origManagementURL })
|
||||||
|
|
||||||
|
managementURL = "https://leftover.example.com"
|
||||||
|
|
||||||
|
// Simulate saved params where management URL was explicitly cleared.
|
||||||
|
saved := &serviceParams{
|
||||||
|
LogLevel: "info",
|
||||||
|
DaemonAddr: "unix:///var/run/netbird.sock",
|
||||||
|
// ManagementURL intentionally empty: was cleared with --management-url "".
|
||||||
|
}
|
||||||
|
|
||||||
|
rootCmd.PersistentFlags().VisitAll(func(f *pflag.Flag) {
|
||||||
|
f.Changed = false
|
||||||
|
})
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
applyServiceParams(cmd, saved)
|
||||||
|
|
||||||
|
assert.Equal(t, "", managementURL, "saved empty management URL should clear the current value")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceParams_NilParams(t *testing.T) {
|
||||||
|
origLogLevel := logLevel
|
||||||
|
t.Cleanup(func() { logLevel = origLogLevel })
|
||||||
|
|
||||||
|
logLevel = "info"
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
|
||||||
|
// Should be a no-op.
|
||||||
|
applyServiceParams(cmd, nil)
|
||||||
|
assert.Equal(t, "info", logLevel)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_MergeExplicitAndSaved(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
// Set up a command with --service-env marked as Changed.
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
require.NoError(t, cmd.Flags().Set("service-env", "EXPLICIT=yes,OVERLAP=explicit"))
|
||||||
|
|
||||||
|
serviceEnvVars = []string{"EXPLICIT=yes", "OVERLAP=explicit"}
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{
|
||||||
|
"SAVED": "val",
|
||||||
|
"OVERLAP": "saved",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
// Parse result for easier assertion.
|
||||||
|
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.Equal(t, "yes", result["EXPLICIT"])
|
||||||
|
assert.Equal(t, "val", result["SAVED"])
|
||||||
|
// Explicit wins on conflict.
|
||||||
|
assert.Equal(t, "explicit", result["OVERLAP"])
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestApplyServiceEnvParams_NotChanged(t *testing.T) {
|
||||||
|
origServiceEnvVars := serviceEnvVars
|
||||||
|
t.Cleanup(func() { serviceEnvVars = origServiceEnvVars })
|
||||||
|
|
||||||
|
serviceEnvVars = nil
|
||||||
|
|
||||||
|
cmd := &cobra.Command{}
|
||||||
|
cmd.Flags().StringSlice("service-env", nil, "")
|
||||||
|
|
||||||
|
saved := &serviceParams{
|
||||||
|
ServiceEnvVars: map[string]string{"FROM_SAVED": "val"},
|
||||||
|
}
|
||||||
|
|
||||||
|
applyServiceEnvParams(cmd, saved)
|
||||||
|
|
||||||
|
result, err := parseServiceEnvVars(serviceEnvVars)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, map[string]string{"FROM_SAVED": "val"}, result)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceParams_FieldsCoveredInFunctions ensures that all serviceParams fields are
|
||||||
|
// referenced in both currentServiceParams() and applyServiceParams(). If a new field is
|
||||||
|
// added to serviceParams but not wired into these functions, this test fails.
|
||||||
|
func TestServiceParams_FieldsCoveredInFunctions(t *testing.T) {
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Collect all JSON field names from the serviceParams struct.
|
||||||
|
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||||
|
require.NotEmpty(t, structFields, "failed to find serviceParams struct fields")
|
||||||
|
|
||||||
|
// Collect field names referenced in currentServiceParams and applyServiceParams.
|
||||||
|
currentFields := extractFuncFieldRefs(t, file, "currentServiceParams", structFields)
|
||||||
|
applyFields := extractFuncFieldRefs(t, file, "applyServiceParams", structFields)
|
||||||
|
// applyServiceEnvParams handles ServiceEnvVars indirectly.
|
||||||
|
applyEnvFields := extractFuncFieldRefs(t, file, "applyServiceEnvParams", structFields)
|
||||||
|
for k, v := range applyEnvFields {
|
||||||
|
applyFields[k] = v
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, field := range structFields {
|
||||||
|
assert.Contains(t, currentFields, field,
|
||||||
|
"serviceParams field %q is not captured in currentServiceParams()", field)
|
||||||
|
assert.Contains(t, applyFields, field,
|
||||||
|
"serviceParams field %q is not restored in applyServiceParams()/applyServiceEnvParams()", field)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestServiceParams_BuildArgsCoversAllFlags ensures that buildServiceArguments references
|
||||||
|
// all serviceParams fields that should become CLI args. ServiceEnvVars is excluded because
|
||||||
|
// it flows through newSVCConfig() EnvVars, not CLI args.
|
||||||
|
func TestServiceParams_BuildArgsCoversAllFlags(t *testing.T) {
|
||||||
|
fset := token.NewFileSet()
|
||||||
|
file, err := parser.ParseFile(fset, "service_params.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
structFields := extractStructJSONFields(t, file, "serviceParams")
|
||||||
|
require.NotEmpty(t, structFields)
|
||||||
|
|
||||||
|
installerFile, err := parser.ParseFile(fset, "service_installer.go", nil, 0)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Fields that are handled outside of buildServiceArguments (env vars go through newSVCConfig).
|
||||||
|
fieldsNotInArgs := map[string]bool{
|
||||||
|
"ServiceEnvVars": true,
|
||||||
|
}
|
||||||
|
|
||||||
|
buildFields := extractFuncGlobalRefs(t, installerFile, "buildServiceArguments")
|
||||||
|
|
||||||
|
// Forward: every struct field must appear in buildServiceArguments.
|
||||||
|
for _, field := range structFields {
|
||||||
|
if fieldsNotInArgs[field] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
globalVar := fieldToGlobalVar(field)
|
||||||
|
assert.Contains(t, buildFields, globalVar,
|
||||||
|
"serviceParams field %q (global %q) is not referenced in buildServiceArguments()", field, globalVar)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reverse: every service-related global used in buildServiceArguments must
|
||||||
|
// have a corresponding serviceParams field. This catches a developer adding
|
||||||
|
// a new flag to buildServiceArguments without adding it to the struct.
|
||||||
|
globalToField := make(map[string]string, len(structFields))
|
||||||
|
for _, field := range structFields {
|
||||||
|
globalToField[fieldToGlobalVar(field)] = field
|
||||||
|
}
|
||||||
|
// Identifiers in buildServiceArguments that are not service params
|
||||||
|
// (builtins, boilerplate, loop variables).
|
||||||
|
nonParamGlobals := map[string]bool{
|
||||||
|
"args": true, "append": true, "string": true, "_": true,
|
||||||
|
"logFile": true, // range variable over logFiles
|
||||||
|
}
|
||||||
|
for ref := range buildFields {
|
||||||
|
if nonParamGlobals[ref] {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, inStruct := globalToField[ref]
|
||||||
|
assert.True(t, inStruct,
|
||||||
|
"buildServiceArguments() references global %q which has no corresponding serviceParams field", ref)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractStructJSONFields returns field names from a named struct type.
|
||||||
|
func extractStructJSONFields(t *testing.T, file *ast.File, structName string) []string {
|
||||||
|
t.Helper()
|
||||||
|
var fields []string
|
||||||
|
ast.Inspect(file, func(n ast.Node) bool {
|
||||||
|
ts, ok := n.(*ast.TypeSpec)
|
||||||
|
if !ok || ts.Name.Name != structName {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
st, ok := ts.Type.(*ast.StructType)
|
||||||
|
if !ok {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
for _, f := range st.Fields.List {
|
||||||
|
if len(f.Names) > 0 {
|
||||||
|
fields = append(fields, f.Names[0].Name)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
return fields
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFuncFieldRefs returns which of the given field names appear inside the
|
||||||
|
// named function, either as selector expressions (params.FieldName) or as
|
||||||
|
// composite literal keys (&serviceParams{FieldName: ...}).
|
||||||
|
func extractFuncFieldRefs(t *testing.T, file *ast.File, funcName string, fields []string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
fieldSet := make(map[string]bool, len(fields))
|
||||||
|
for _, f := range fields {
|
||||||
|
fieldSet[f] = true
|
||||||
|
}
|
||||||
|
|
||||||
|
found := make(map[string]bool)
|
||||||
|
fn := findFuncDecl(file, funcName)
|
||||||
|
require.NotNil(t, fn, "function %s not found", funcName)
|
||||||
|
|
||||||
|
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||||
|
switch v := n.(type) {
|
||||||
|
case *ast.SelectorExpr:
|
||||||
|
if fieldSet[v.Sel.Name] {
|
||||||
|
found[v.Sel.Name] = true
|
||||||
|
}
|
||||||
|
case *ast.KeyValueExpr:
|
||||||
|
if ident, ok := v.Key.(*ast.Ident); ok && fieldSet[ident.Name] {
|
||||||
|
found[ident.Name] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return found
|
||||||
|
}
|
||||||
|
|
||||||
|
// extractFuncGlobalRefs returns all identifier names referenced in the named function body.
|
||||||
|
func extractFuncGlobalRefs(t *testing.T, file *ast.File, funcName string) map[string]bool {
|
||||||
|
t.Helper()
|
||||||
|
fn := findFuncDecl(file, funcName)
|
||||||
|
require.NotNil(t, fn, "function %s not found", funcName)
|
||||||
|
|
||||||
|
refs := make(map[string]bool)
|
||||||
|
ast.Inspect(fn.Body, func(n ast.Node) bool {
|
||||||
|
if ident, ok := n.(*ast.Ident); ok {
|
||||||
|
refs[ident.Name] = true
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
return refs
|
||||||
|
}
|
||||||
|
|
||||||
|
func findFuncDecl(file *ast.File, name string) *ast.FuncDecl {
|
||||||
|
for _, decl := range file.Decls {
|
||||||
|
fn, ok := decl.(*ast.FuncDecl)
|
||||||
|
if ok && fn.Name.Name == name {
|
||||||
|
return fn
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// fieldToGlobalVar maps serviceParams field names to the package-level variable
|
||||||
|
// names used in buildServiceArguments and applyServiceParams.
|
||||||
|
func fieldToGlobalVar(field string) string {
|
||||||
|
m := map[string]string{
|
||||||
|
"LogLevel": "logLevel",
|
||||||
|
"DaemonAddr": "daemonAddr",
|
||||||
|
"ManagementURL": "managementURL",
|
||||||
|
"ConfigPath": "configPath",
|
||||||
|
"LogFiles": "logFiles",
|
||||||
|
"DisableProfiles": "profilesDisabled",
|
||||||
|
"DisableUpdateSettings": "updateSettingsDisabled",
|
||||||
|
"ServiceEnvVars": "serviceEnvVars",
|
||||||
|
}
|
||||||
|
if v, ok := m[field]; ok {
|
||||||
|
return v
|
||||||
|
}
|
||||||
|
// Default: lowercase first letter.
|
||||||
|
return strings.ToLower(field[:1]) + field[1:]
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMapToSlice(t *testing.T) {
|
||||||
|
m := map[string]string{"A": "1", "B": "2"}
|
||||||
|
s := envMapToSlice(m)
|
||||||
|
assert.Len(t, s, 2)
|
||||||
|
assert.Contains(t, s, "A=1")
|
||||||
|
assert.Contains(t, s, "B=2")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestEnvMapToSlice_Empty(t *testing.T) {
|
||||||
|
s := envMapToSlice(map[string]string{})
|
||||||
|
assert.Empty(t, s)
|
||||||
|
}
|
||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -6,7 +6,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
|||||||
@@ -7,7 +7,7 @@ import (
|
|||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/reposign"
|
"github.com/netbirdio/netbird/client/internal/updater/reposign"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
|
|||||||
@@ -28,6 +28,7 @@ var (
|
|||||||
ipsFilterMap map[string]struct{}
|
ipsFilterMap map[string]struct{}
|
||||||
prefixNamesFilterMap map[string]struct{}
|
prefixNamesFilterMap map[string]struct{}
|
||||||
connectionTypeFilter string
|
connectionTypeFilter string
|
||||||
|
checkFlag string
|
||||||
)
|
)
|
||||||
|
|
||||||
var statusCmd = &cobra.Command{
|
var statusCmd = &cobra.Command{
|
||||||
@@ -49,6 +50,7 @@ func init() {
|
|||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
statusCmd.PersistentFlags().StringVar(&connectionTypeFilter, "filter-by-connection-type", "", "filters the detailed output by connection type (P2P|Relayed), e.g., --filter-by-connection-type P2P")
|
||||||
|
statusCmd.PersistentFlags().StringVar(&checkFlag, "check", "", "run a health check and exit with code 0 on success, 1 on failure (live|ready|startup)")
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -56,6 +58,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
cmd.SetOut(cmd.OutOrStdout())
|
cmd.SetOut(cmd.OutOrStdout())
|
||||||
|
|
||||||
|
if checkFlag != "" {
|
||||||
|
return runHealthCheck(cmd)
|
||||||
|
}
|
||||||
|
|
||||||
err := parseFilters()
|
err := parseFilters()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -68,15 +74,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
resp, err := getStatus(ctx, false)
|
resp, err := getStatus(ctx, true, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
status := resp.GetStatus()
|
status := resp.GetStatus()
|
||||||
|
|
||||||
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
needsAuth := status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||||
status == string(internal.StatusSessionExpired) {
|
status == string(internal.StatusSessionExpired)
|
||||||
|
|
||||||
|
if needsAuth && !jsonFlag && !yamlFlag {
|
||||||
cmd.Printf("Daemon status: %s\n\n"+
|
cmd.Printf("Daemon status: %s\n\n"+
|
||||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||||
" netbird up \n\n"+
|
" netbird up \n\n"+
|
||||||
@@ -99,7 +107,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
profName = activeProf.Name
|
profName = activeProf.Name
|
||||||
}
|
}
|
||||||
|
|
||||||
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), anonymizeFlag, resp.GetDaemonVersion(), statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap, connectionTypeFilter, profName)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp.GetFullStatus(), nbstatus.ConvertOptions{
|
||||||
|
Anonymize: anonymizeFlag,
|
||||||
|
DaemonVersion: resp.GetDaemonVersion(),
|
||||||
|
DaemonStatus: nbstatus.ParseDaemonStatus(status),
|
||||||
|
StatusFilter: statusFilter,
|
||||||
|
PrefixNamesFilter: prefixNamesFilter,
|
||||||
|
PrefixNamesFilterMap: prefixNamesFilterMap,
|
||||||
|
IPsFilter: ipsFilterMap,
|
||||||
|
ConnectionTypeFilter: connectionTypeFilter,
|
||||||
|
ProfileName: profName,
|
||||||
|
})
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
@@ -121,7 +139,7 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
func getStatus(ctx context.Context, fullPeerStatus bool, shouldRunProbes bool) (*proto.StatusResponse, error) {
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
//nolint
|
//nolint
|
||||||
@@ -131,7 +149,7 @@ func getStatus(ctx context.Context, shouldRunProbes bool) (*proto.StatusResponse
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: shouldRunProbes})
|
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: fullPeerStatus, ShouldRunProbes: shouldRunProbes})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
@@ -185,6 +203,83 @@ func enableDetailFlagWhenFilterFlag() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runHealthCheck(cmd *cobra.Command) error {
|
||||||
|
check := strings.ToLower(checkFlag)
|
||||||
|
switch check {
|
||||||
|
case "live", "ready", "startup":
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown check %q, must be one of: live, ready, startup", checkFlag)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := util.InitLog(logLevel, util.LogConsole); err != nil {
|
||||||
|
return fmt.Errorf("init log: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
|
isStartup := check == "startup"
|
||||||
|
resp, err := getStatus(ctx, isStartup, false)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
switch check {
|
||||||
|
case "live":
|
||||||
|
return nil
|
||||||
|
case "ready":
|
||||||
|
return checkReadiness(resp)
|
||||||
|
case "startup":
|
||||||
|
return checkStartup(resp)
|
||||||
|
default:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkReadiness(resp *proto.StatusResponse) error {
|
||||||
|
daemonStatus := internal.StatusType(resp.GetStatus())
|
||||||
|
switch daemonStatus {
|
||||||
|
case internal.StatusIdle, internal.StatusConnecting, internal.StatusConnected:
|
||||||
|
return nil
|
||||||
|
case internal.StatusNeedsLogin, internal.StatusLoginFailed, internal.StatusSessionExpired:
|
||||||
|
return fmt.Errorf("readiness check: daemon status is %s", daemonStatus)
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("readiness check: unexpected daemon status %q", daemonStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkStartup(resp *proto.StatusResponse) error {
|
||||||
|
fullStatus := resp.GetFullStatus()
|
||||||
|
if fullStatus == nil {
|
||||||
|
return fmt.Errorf("startup check: no full status available")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fullStatus.GetManagementState().GetConnected() {
|
||||||
|
return fmt.Errorf("startup check: management not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
if !fullStatus.GetSignalState().GetConnected() {
|
||||||
|
return fmt.Errorf("startup check: signal not connected")
|
||||||
|
}
|
||||||
|
|
||||||
|
var relayCount, relaysConnected int
|
||||||
|
for _, r := range fullStatus.GetRelays() {
|
||||||
|
uri := r.GetURI()
|
||||||
|
if !strings.HasPrefix(uri, "rel://") && !strings.HasPrefix(uri, "rels://") {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
relayCount++
|
||||||
|
if r.GetAvailable() {
|
||||||
|
relaysConnected++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if relayCount > 0 && relaysConnected == 0 {
|
||||||
|
return fmt.Errorf("startup check: no relay servers available (0/%d connected)", relayCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func parseInterfaceIP(interfaceIP string) string {
|
func parseInterfaceIP(interfaceIP string) string {
|
||||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -197,7 +197,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command, activeProf *pr
|
|||||||
r := peer.NewRecorder(config.ManagementURL.String())
|
r := peer.NewRecorder(config.ManagementURL.String())
|
||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r, false)
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
SetupDebugHandler(ctx, config, r, connectClient, "")
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
return connectClient.Run(nil, util.FindFirstLogPath(logFiles))
|
||||||
|
|||||||
@@ -11,7 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/auth"
|
"github.com/netbirdio/netbird/client/internal/auth"
|
||||||
@@ -21,6 +22,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
sshcommon "github.com/netbirdio/netbird/client/ssh"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,6 +33,14 @@ var (
|
|||||||
ErrConfigNotInitialized = errors.New("config not initialized")
|
ErrConfigNotInitialized = errors.New("config not initialized")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// PeerStatusConnected indicates the peer is in connected state.
|
||||||
|
PeerStatusConnected = peer.StatusConnected
|
||||||
|
)
|
||||||
|
|
||||||
|
// PeerConnStatus is a peer's connection status.
|
||||||
|
type PeerConnStatus = peer.ConnStatus
|
||||||
|
|
||||||
// Client manages a netbird embedded client instance.
|
// Client manages a netbird embedded client instance.
|
||||||
type Client struct {
|
type Client struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
@@ -71,6 +81,16 @@ type Options struct {
|
|||||||
DisableClientRoutes bool
|
DisableClientRoutes bool
|
||||||
// BlockInbound blocks all inbound connections from peers
|
// BlockInbound blocks all inbound connections from peers
|
||||||
BlockInbound bool
|
BlockInbound bool
|
||||||
|
// WireguardPort is the port for the WireGuard interface. Use 0 for a random port.
|
||||||
|
WireguardPort *int
|
||||||
|
// MTU is the MTU for the WireGuard interface.
|
||||||
|
// Valid values are in the range 576..8192 bytes.
|
||||||
|
// If non-nil, this value overrides any value stored in the config file.
|
||||||
|
// If nil, the existing config MTU (if non-zero) is preserved; otherwise it defaults to 1280.
|
||||||
|
// Set to a higher value (e.g. 1400) if carrying QUIC or other protocols that require larger datagrams.
|
||||||
|
MTU *uint16
|
||||||
|
// DNSLabels defines additional DNS labels configured in the peer.
|
||||||
|
DNSLabels []string
|
||||||
}
|
}
|
||||||
|
|
||||||
// validateCredentials checks that exactly one credential type is provided
|
// validateCredentials checks that exactly one credential type is provided
|
||||||
@@ -102,6 +122,12 @@ func New(opts Options) (*Client, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if opts.MTU != nil {
|
||||||
|
if err := iface.ValidateMTU(*opts.MTU); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid MTU: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if opts.LogOutput != nil {
|
if opts.LogOutput != nil {
|
||||||
logrus.SetOutput(opts.LogOutput)
|
logrus.SetOutput(opts.LogOutput)
|
||||||
}
|
}
|
||||||
@@ -130,9 +156,14 @@ func New(opts Options) (*Client, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var err error
|
||||||
|
var parsedLabels domain.List
|
||||||
|
if parsedLabels, err = domain.FromStringList(opts.DNSLabels); err != nil {
|
||||||
|
return nil, fmt.Errorf("invalid dns labels: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
t := true
|
t := true
|
||||||
var config *profilemanager.Config
|
var config *profilemanager.Config
|
||||||
var err error
|
|
||||||
input := profilemanager.ConfigInput{
|
input := profilemanager.ConfigInput{
|
||||||
ConfigPath: opts.ConfigPath,
|
ConfigPath: opts.ConfigPath,
|
||||||
ManagementURL: opts.ManagementURL,
|
ManagementURL: opts.ManagementURL,
|
||||||
@@ -140,6 +171,9 @@ func New(opts Options) (*Client, error) {
|
|||||||
DisableServerRoutes: &t,
|
DisableServerRoutes: &t,
|
||||||
DisableClientRoutes: &opts.DisableClientRoutes,
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
BlockInbound: &opts.BlockInbound,
|
BlockInbound: &opts.BlockInbound,
|
||||||
|
WireguardPort: opts.WireguardPort,
|
||||||
|
MTU: opts.MTU,
|
||||||
|
DNSLabels: parsedLabels,
|
||||||
}
|
}
|
||||||
if opts.ConfigPath != "" {
|
if opts.ConfigPath != "" {
|
||||||
config, err = profilemanager.UpdateOrCreateConfig(input)
|
config, err = profilemanager.UpdateOrCreateConfig(input)
|
||||||
@@ -159,6 +193,7 @@ func New(opts Options) (*Client, error) {
|
|||||||
setupKey: opts.SetupKey,
|
setupKey: opts.SetupKey,
|
||||||
jwtToken: opts.JWTToken,
|
jwtToken: opts.JWTToken,
|
||||||
config: config,
|
config: config,
|
||||||
|
recorder: peer.NewRecorder(config.ManagementURL.String()),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -180,6 +215,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
|
|
||||||
// nolint:staticcheck
|
// nolint:staticcheck
|
||||||
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
|
||||||
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
authClient, err := auth.NewAuth(ctx, c.config.PrivateKey, c.config.ManagementURL, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("create auth client: %w", err)
|
return fmt.Errorf("create auth client: %w", err)
|
||||||
@@ -189,10 +225,7 @@ func (c *Client) Start(startCtx context.Context) error {
|
|||||||
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
if err, _ := authClient.Login(ctx, c.setupKey, c.jwtToken); err != nil {
|
||||||
return fmt.Errorf("login: %w", err)
|
return fmt.Errorf("login: %w", err)
|
||||||
}
|
}
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, c.recorder)
|
||||||
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
|
||||||
c.recorder = recorder
|
|
||||||
client := internal.NewConnectClient(ctx, c.config, recorder, false)
|
|
||||||
client.SetSyncResponsePersistence(true)
|
client.SetSyncResponsePersistence(true)
|
||||||
|
|
||||||
// either startup error (permanent backoff err) or nil err (successful engine up)
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
@@ -342,17 +375,38 @@ func (c *Client) NewHTTPClient() *http.Client {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Expose exposes a local service via the NetBird reverse proxy, making it accessible through a public URL.
|
||||||
|
// It returns an ExposeSession. Call Wait on the session to keep it alive.
|
||||||
|
func (c *Client) Expose(ctx context.Context, req ExposeRequest) (*ExposeSession, error) {
|
||||||
|
engine, err := c.getEngine()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
mgr := engine.GetExposeManager()
|
||||||
|
if mgr == nil {
|
||||||
|
return nil, fmt.Errorf("expose manager not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
resp, err := mgr.Expose(ctx, req)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("expose: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &ExposeSession{
|
||||||
|
Domain: resp.Domain,
|
||||||
|
ServiceName: resp.ServiceName,
|
||||||
|
ServiceURL: resp.ServiceURL,
|
||||||
|
mgr: mgr,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current status of the client.
|
// Status returns the current status of the client.
|
||||||
func (c *Client) Status() (peer.FullStatus, error) {
|
func (c *Client) Status() (peer.FullStatus, error) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
recorder := c.recorder
|
|
||||||
connect := c.connect
|
connect := c.connect
|
||||||
c.mu.Unlock()
|
c.mu.Unlock()
|
||||||
|
|
||||||
if recorder == nil {
|
|
||||||
return peer.FullStatus{}, errors.New("client not started")
|
|
||||||
}
|
|
||||||
|
|
||||||
if connect != nil {
|
if connect != nil {
|
||||||
engine := connect.Engine()
|
engine := connect.Engine()
|
||||||
if engine != nil {
|
if engine != nil {
|
||||||
@@ -360,7 +414,7 @@ func (c *Client) Status() (peer.FullStatus, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return recorder.GetFullStatus(), nil
|
return c.recorder.GetFullStatus(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetLatestSyncResponse returns the latest sync response from the management server.
|
// GetLatestSyncResponse returns the latest sync response from the management server.
|
||||||
|
|||||||
45
client/embed/expose.go
Normal file
45
client/embed/expose.go
Normal file
@@ -0,0 +1,45 @@
|
|||||||
|
package embed
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ExposeProtocolHTTP exposes the service as HTTP.
|
||||||
|
ExposeProtocolHTTP = expose.ProtocolHTTP
|
||||||
|
// ExposeProtocolHTTPS exposes the service as HTTPS.
|
||||||
|
ExposeProtocolHTTPS = expose.ProtocolHTTPS
|
||||||
|
// ExposeProtocolTCP exposes the service as TCP.
|
||||||
|
ExposeProtocolTCP = expose.ProtocolTCP
|
||||||
|
// ExposeProtocolUDP exposes the service as UDP.
|
||||||
|
ExposeProtocolUDP = expose.ProtocolUDP
|
||||||
|
// ExposeProtocolTLS exposes the service as TLS.
|
||||||
|
ExposeProtocolTLS = expose.ProtocolTLS
|
||||||
|
)
|
||||||
|
|
||||||
|
// ExposeRequest is a request to expose a local service via the NetBird reverse proxy.
|
||||||
|
type ExposeRequest = expose.Request
|
||||||
|
|
||||||
|
// ExposeProtocolType represents the protocol used for exposing a service.
|
||||||
|
type ExposeProtocolType = expose.ProtocolType
|
||||||
|
|
||||||
|
// ExposeSession represents an active expose session. Use Wait to block until the session ends.
|
||||||
|
type ExposeSession struct {
|
||||||
|
Domain string
|
||||||
|
ServiceName string
|
||||||
|
ServiceURL string
|
||||||
|
|
||||||
|
mgr *expose.Manager
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait blocks while keeping the expose session alive.
|
||||||
|
// It returns when ctx is cancelled or a keep-alive error occurs, then terminates the session.
|
||||||
|
func (s *ExposeSession) Wait(ctx context.Context) error {
|
||||||
|
if s == nil || s.mgr == nil {
|
||||||
|
return errors.New("expose session is not initialized")
|
||||||
|
}
|
||||||
|
return s.mgr.KeepAlive(ctx, s.Domain)
|
||||||
|
}
|
||||||
@@ -23,9 +23,10 @@ type Manager struct {
|
|||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
|
||||||
ipv4Client *iptables.IPTables
|
ipv4Client *iptables.IPTables
|
||||||
aclMgr *aclManager
|
aclMgr *aclManager
|
||||||
router *router
|
router *router
|
||||||
|
rawSupported bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
@@ -84,7 +85,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initNoTrackChain(); err != nil {
|
if err := m.initNoTrackChain(); err != nil {
|
||||||
return fmt.Errorf("init notrack chain: %w", err)
|
log.Warnf("raw table not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// persist early to ensure cleanup of chains
|
// persist early to ensure cleanup of chains
|
||||||
@@ -318,6 +319,10 @@ func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
if !m.rawSupported {
|
||||||
|
return fmt.Errorf("raw table not available")
|
||||||
|
}
|
||||||
|
|
||||||
wgPortStr := fmt.Sprintf("%d", wgPort)
|
wgPortStr := fmt.Sprintf("%d", wgPort)
|
||||||
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
proxyPortStr := fmt.Sprintf("%d", proxyPort)
|
||||||
|
|
||||||
@@ -375,12 +380,16 @@ func (m *Manager) initNoTrackChain() error {
|
|||||||
return fmt.Errorf("add prerouting jump rule: %w", err)
|
return fmt.Errorf("add prerouting jump rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.rawSupported = true
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) cleanupNoTrackChain() error {
|
func (m *Manager) cleanupNoTrackChain() error {
|
||||||
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
exists, err := m.ipv4Client.ChainExists(tableRaw, chainNameRaw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
if !m.rawSupported {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
return fmt.Errorf("check chain exists: %w", err)
|
return fmt.Errorf("check chain exists: %w", err)
|
||||||
}
|
}
|
||||||
if !exists {
|
if !exists {
|
||||||
@@ -401,6 +410,7 @@ func (m *Manager) cleanupNoTrackChain() error {
|
|||||||
return fmt.Errorf("clear and delete chain: %w", err)
|
return fmt.Errorf("clear and delete chain: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.rawSupported = false
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -95,7 +95,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := m.initNoTrackChains(workTable); err != nil {
|
if err := m.initNoTrackChains(workTable); err != nil {
|
||||||
return fmt.Errorf("init notrack chains: %w", err)
|
log.Warnf("raw priority chains not available, notrack rules will be disabled: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stateManager.RegisterState(&ShutdownState{})
|
stateManager.RegisterState(&ShutdownState{})
|
||||||
|
|||||||
@@ -483,7 +483,12 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if nftRule.Handle == 0 {
|
if nftRule.Handle == 0 {
|
||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
log.Warnf("route rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
@@ -660,13 +665,32 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback ipset counter
|
r.rollbackRules(pair)
|
||||||
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
return fmt.Errorf("insert rules for %s: %w", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// rollbackRules cleans up unflushed rules and their set counters after a flush failure.
|
||||||
|
func (r *router) rollbackRules(pair firewall.RouterPair) {
|
||||||
|
keys := []string{
|
||||||
|
firewall.GenKey(firewall.ForwardingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, pair),
|
||||||
|
firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair)),
|
||||||
|
}
|
||||||
|
for _, key := range keys {
|
||||||
|
rule, ok := r.rules[key]
|
||||||
|
if !ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("rollback set counter for %s: %v", key, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
@@ -928,18 +952,30 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("legacy forwarding rule %s has no handle, removing stale entry", ruleKey)
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
}
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1329,65 +1365,89 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
if pair.Masquerade {
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove inverse prerouting rule: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy routing rule: %w", err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Set counters are decremented in the sub-methods above before flush. If flush fails,
|
||||||
|
// counters will be off until the next successful removal or refresh cycle.
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
// TODO: rollback set counter
|
merr = multierror.Append(merr, fmt.Errorf("flush remove nat rules %s: %w", pair.Destination, err))
|
||||||
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
rule, exists := r.rules[ruleKey]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
|
||||||
|
|
||||||
if err := r.decrementSetCounter(rule); err != nil {
|
|
||||||
return fmt.Errorf("decrement set counter: %w", err)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
log.Debugf("prerouting rule %s not found", ruleKey)
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("prerouting rule %s has no handle, removing stale entry", ruleKey)
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
log.Warnf("decrement set counter for stale rule %s: %v", ruleKey, err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("remove prerouting rule %s -> %s: %w", pair.Source, pair.Destination, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// refreshRulesMap refreshes the rule map with the latest rules. this is useful to avoid
|
// refreshRulesMap rebuilds the rule map from the kernel. This removes stale entries
|
||||||
// duplicates and to get missing attributes that we don't have when adding new rules
|
// (e.g. from failed flushes) and updates handles for all existing rules.
|
||||||
func (r *router) refreshRulesMap() error {
|
func (r *router) refreshRulesMap() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
newRules := make(map[string]*nftables.Rule)
|
||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("list rules: %w", err)
|
merr = multierror.Append(merr, fmt.Errorf("list rules for chain %s: %w", chain.Name, err))
|
||||||
|
// preserve existing entries for this chain since we can't verify their state
|
||||||
|
for k, v := range r.rules {
|
||||||
|
if v.Chain != nil && v.Chain.Name == chain.Name {
|
||||||
|
newRules[k] = v
|
||||||
|
}
|
||||||
|
}
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
r.rules[string(rule.UserData)] = rule
|
newRules[string(rule.UserData)] = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return nil
|
r.rules = newRules
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
@@ -1629,20 +1689,34 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
var needsFlush bool
|
||||||
|
|
||||||
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(dnatRule); err != nil {
|
if dnatRule.Handle == 0 {
|
||||||
|
log.Warnf("dnat rule %s has no handle, removing stale entry", ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
if err := r.conn.DelRule(masqRule); err != nil {
|
if masqRule.Handle == 0 {
|
||||||
|
log.Warnf("snat rule %s has no handle, removing stale entry", ruleKey+snatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
} else if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
} else {
|
||||||
|
needsFlush = true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if needsFlush {
|
||||||
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if merr == nil {
|
if merr == nil {
|
||||||
@@ -1757,16 +1831,25 @@ func (r *router) RemoveInboundDNAT(localAddr netip.Addr, protocol firewall.Proto
|
|||||||
|
|
||||||
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
ruleID := fmt.Sprintf("inbound-dnat-%s-%s-%d-%d", localAddr.String(), protocol, sourcePort, targetPort)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleID]; exists {
|
rule, exists := r.rules[ruleID]
|
||||||
if err := r.conn.DelRule(rule); err != nil {
|
if !exists {
|
||||||
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
return nil
|
||||||
}
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
|
||||||
}
|
|
||||||
delete(r.rules, ruleID)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if rule.Handle == 0 {
|
||||||
|
log.Warnf("inbound DNAT rule %s has no handle, removing stale entry", ruleID)
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
return fmt.Errorf("delete inbound DNAT rule %s: %w", ruleID, err)
|
||||||
|
}
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush delete inbound DNAT rule: %w", err)
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleID)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/test"
|
"github.com/netbirdio/netbird/client/firewall/test"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -719,3 +720,137 @@ func deleteWorkTable() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouter_RefreshRulesMap_RemovesStaleEntries(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Add a real rule to the kernel
|
||||||
|
ruleKey, err := r.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
firewall.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
|
firewall.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&firewall.Port{Values: []uint16{80}},
|
||||||
|
firewall.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, r.DeleteRouteRule(ruleKey))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0 (simulates store-before-flush failure)
|
||||||
|
staleKey := "stale-rule-that-does-not-exist"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Contains(t, r.rules, staleKey, "stale entry should be in map before refresh")
|
||||||
|
|
||||||
|
err = r.refreshRulesMap()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be removed after refresh")
|
||||||
|
|
||||||
|
realRule, ok := r.rules[ruleKey.ID()]
|
||||||
|
assert.True(t, ok, "real rule should still exist after refresh")
|
||||||
|
assert.NotZero(t, realRule.Handle, "real rule should have a valid handle")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_DeleteRouteRule_StaleHandle(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
workTable, err := createWorkTable()
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer deleteWorkTable()
|
||||||
|
|
||||||
|
r, err := newRouter(workTable, ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, r.init(workTable))
|
||||||
|
defer func() { require.NoError(t, r.Reset()) }()
|
||||||
|
|
||||||
|
// Inject a stale entry with Handle=0
|
||||||
|
staleKey := "stale-route-rule"
|
||||||
|
r.rules[staleKey] = &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
|
Handle: 0,
|
||||||
|
UserData: []byte(staleKey),
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteRouteRule should not return an error for stale handles
|
||||||
|
err = r.DeleteRouteRule(id.RuleID(staleKey))
|
||||||
|
assert.NoError(t, err, "deleting a stale rule should not error")
|
||||||
|
assert.NotContains(t, r.rules, staleKey, "stale entry should be cleaned up")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRouter_AddNatRule_WithStaleEntry(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
pair := firewall.RouterPair{
|
||||||
|
ID: "staletest",
|
||||||
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
rtr := manager.router
|
||||||
|
|
||||||
|
// First add succeeds
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, rtr.RemoveNatRule(pair))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Corrupt the handle to simulate stale state
|
||||||
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
if rule, exists := rtr.rules[natRuleKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
inverseKey := firewall.GenKey(firewall.PreroutingFormat, firewall.GetInversePair(pair))
|
||||||
|
if rule, exists := rtr.rules[inverseKey]; exists {
|
||||||
|
rule.Handle = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
// Adding the same rule again should succeed despite stale handles
|
||||||
|
err = rtr.AddNatRule(pair)
|
||||||
|
assert.NoError(t, err, "AddNatRule should succeed even with stale entries")
|
||||||
|
|
||||||
|
// Verify rules exist in kernel
|
||||||
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
found := 0
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
|
found++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Equal(t, 1, found, "NAT rule should exist in kernel")
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,12 +3,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"net/netip"
|
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -17,33 +11,7 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Close(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
|
|||||||
@@ -1,12 +1,9 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -26,33 +23,7 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
m.resetState()
|
||||||
m.incomingDenyRules = make(map[netip.Addr]RuleSet)
|
|
||||||
m.incomingRules = make(map[netip.Addr]RuleSet)
|
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
|
||||||
m.udpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
|
||||||
m.icmpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
|
||||||
m.tcpTracker.Close()
|
|
||||||
}
|
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
|
||||||
fwder.Stop()
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.logger != nil {
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := m.logger.Stop(ctx); err != nil {
|
|
||||||
log.Errorf("failed to shutdown logger: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -115,6 +115,17 @@ func (t *TCPConnTrack) IsTombstone() bool {
|
|||||||
return t.tombstone.Load()
|
return t.tombstone.Load()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// IsSupersededBy returns true if this connection should be replaced by a new one
|
||||||
|
// carrying the given flags. Tombstoned connections are always superseded; TIME-WAIT
|
||||||
|
// connections are superseded by a pure SYN (a new connection attempt for the same
|
||||||
|
// four-tuple, as contemplated by RFC 1122 §4.2.2.13 and RFC 6191).
|
||||||
|
func (t *TCPConnTrack) IsSupersededBy(flags uint8) bool {
|
||||||
|
if t.tombstone.Load() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0 && TCPState(t.state.Load()) == TCPStateTimeWait
|
||||||
|
}
|
||||||
|
|
||||||
// SetTombstone safely marks the connection for deletion
|
// SetTombstone safely marks the connection for deletion
|
||||||
func (t *TCPConnTrack) SetTombstone() {
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
t.tombstone.Store(true)
|
t.tombstone.Store(true)
|
||||||
@@ -169,7 +180,7 @@ func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists && !conn.IsSupersededBy(flags) {
|
||||||
t.updateState(key, conn, flags, direction, size)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
return key, uint16(conn.DNATOrigPort.Load()), true
|
return key, uint16(conn.DNATOrigPort.Load()), true
|
||||||
}
|
}
|
||||||
@@ -241,7 +252,7 @@ func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort ui
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists || conn.IsTombstone() {
|
if !exists || conn.IsSupersededBy(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -485,6 +485,261 @@ func TestTCPAbnormalSequences(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestTCPPortReuseTombstone verifies that a new connection on a port with a
|
||||||
|
// tombstoned (closed) conntrack entry is properly tracked. Without the fix,
|
||||||
|
// updateIfExists treats tombstoned entries as live, causing track() to skip
|
||||||
|
// creating a new connection. The subsequent SYN-ACK then fails IsValidInbound
|
||||||
|
// because the entry is tombstoned, and the response packet gets dropped by ACL.
|
||||||
|
func TestTCPPortReuseTombstone(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after graceful close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and gracefully close a connection (server-initiated close)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Server sends FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Client sends FIN-ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
|
// Server sends final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Connection should be tombstoned
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn, "old connection should still be in map")
|
||||||
|
require.True(t, conn.IsTombstone(), "old connection should be tombstoned")
|
||||||
|
|
||||||
|
// Now reuse the same port for a new connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// The old tombstoned entry should be replaced with a new one
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
// SYN-ACK for the new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection on reused port should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
|
||||||
|
// Data transfer should work
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 500)
|
||||||
|
require.True(t, valid, "data should be allowed on new connection")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse after RST", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and RST a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone(), "RST connection should be tombstoned")
|
||||||
|
|
||||||
|
// Reuse the same port
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState())
|
||||||
|
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK should be accepted after RST tombstone")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound port reuse after close", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := srcIP
|
||||||
|
serverIP := dstIP
|
||||||
|
clientPort := srcPort
|
||||||
|
serverPort := dstPort
|
||||||
|
key := ConnKey{SrcIP: clientIP, DstIP: serverIP, SrcPort: clientPort, DstPort: serverPort}
|
||||||
|
|
||||||
|
// Inbound connection: client SYN → server SYN-ACK → client ACK
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Server-initiated close to reach Closed/tombstoned:
|
||||||
|
// Server FIN (opposite dir) → CloseWait
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPFin|TCPAck, 100)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
// Client FIN-ACK (same dir as conn) → LastAck
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPFin|TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
// Server final ACK (opposite dir) → Closed → tombstoned
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// New inbound connection on same ports
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn)
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
|
||||||
|
// Complete handshake: server SYN-ACK, then client ACK
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100, 0)
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late ACK on tombstoned connection is harmless", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and close via passive close (server-initiated FIN → Closed → tombstoned)
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0) // CloseWait
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0) // LastAck
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0) // Closed
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
|
||||||
|
// Late ACK should be rejected (tombstoned)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.False(t, valid, "late ACK on tombstoned connection should be rejected")
|
||||||
|
|
||||||
|
// Late outbound ACK should not create a new connection (not a SYN)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.True(t, tracker.connections[key].IsTombstone(), "late outbound ACK should not replace tombstoned entry")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPPortReuseTimeWait(t *testing.T) {
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Outbound port reuse during TIME-WAIT (active close)", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Active close: client (outbound initiator) sends FIN first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Server ACKs the FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Server sends its own FIN
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Client sends final ACK (TIME-WAIT stays, not tombstoned)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
require.False(t, conn.IsTombstone(), "TIME-WAIT should not be tombstoned")
|
||||||
|
|
||||||
|
// New outbound SYN on the same port (port reuse during TIME-WAIT)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
|
// Per RFC 1122/6191, new SYN during TIME-WAIT should start a new connection
|
||||||
|
newConn := tracker.connections[key]
|
||||||
|
require.NotNil(t, newConn, "new connection should exist")
|
||||||
|
require.False(t, newConn.IsTombstone(), "new connection should not be tombstoned")
|
||||||
|
require.Equal(t, TCPStateSynSent, newConn.GetState(), "new connection should be in SYN-SENT")
|
||||||
|
|
||||||
|
// SYN-ACK for new connection should be valid
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
|
require.True(t, valid, "SYN-ACK for new connection should be accepted")
|
||||||
|
require.Equal(t, TCPStateEstablished, newConn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Inbound SYN during TIME-WAIT falls through to normal tracking", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish outbound connection and close via active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Inbound SYN on same ports during TIME-WAIT: IsValidInbound returns false
|
||||||
|
// so the filter falls through to ACL check + TrackInbound (which creates
|
||||||
|
// a new connection via track() → updateIfExists skips TIME-WAIT for SYN)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, 0)
|
||||||
|
require.False(t, valid, "inbound SYN during TIME-WAIT should fail conntrack validation")
|
||||||
|
|
||||||
|
// Simulate what the filter does next: TrackInbound via the normal path
|
||||||
|
tracker.TrackInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn, nil, 100, 0)
|
||||||
|
|
||||||
|
// The new inbound connection uses the inverted key (dst→src becomes src→dst in track)
|
||||||
|
invertedKey := ConnKey{SrcIP: dstIP, DstIP: srcIP, SrcPort: dstPort, DstPort: srcPort}
|
||||||
|
newConn := tracker.connections[invertedKey]
|
||||||
|
require.NotNil(t, newConn, "new inbound connection should be tracked")
|
||||||
|
require.Equal(t, TCPStateSynReceived, newConn.GetState())
|
||||||
|
require.False(t, newConn.IsTombstone())
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Late retransmit during TIME-WAIT still allowed", func(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
key := ConnKey{SrcIP: srcIP, DstIP: dstIP, SrcPort: srcPort, DstPort: dstPort}
|
||||||
|
|
||||||
|
// Establish and active close → TIME-WAIT
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Late ACK retransmits during TIME-WAIT should still be accepted
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "retransmitted ACK during TIME-WAIT should be accepted")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func TestTCPTimeoutHandling(t *testing.T) {
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
// Create tracker with a very short timeout for testing
|
// Create tracker with a very short timeout for testing
|
||||||
shortTimeout := 100 * time.Millisecond
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -12,11 +13,13 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
@@ -24,6 +27,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
@@ -89,6 +93,7 @@ type Manager struct {
|
|||||||
incomingDenyRules map[netip.Addr]RuleSet
|
incomingDenyRules map[netip.Addr]RuleSet
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
|
routeRulesMap map[nbid.RuleID]*RouteRule
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -229,6 +234,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
routeRulesMap: make(map[nbid.RuleID]*RouteRule),
|
||||||
dnatMappings: make(map[netip.Addr]netip.Addr),
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
portDNATRules: []portDNATRule{},
|
portDNATRules: []portDNATRule{},
|
||||||
netstackServices: make(map[serviceKey]struct{}),
|
netstackServices: make(map[serviceKey]struct{}),
|
||||||
@@ -480,11 +486,15 @@ func (m *Manager) addRouteFiltering(
|
|||||||
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
|
if existingRule, ok := m.routeRulesMap[ruleKey]; ok {
|
||||||
|
return existingRule, nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: string(ruleKey),
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
dstSet: destination.Set,
|
dstSet: destination.Set,
|
||||||
@@ -499,6 +509,7 @@ func (m *Manager) addRouteFiltering(
|
|||||||
|
|
||||||
m.routeRules = append(m.routeRules, &rule)
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.routeRulesMap[ruleKey] = &rule
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
@@ -515,15 +526,20 @@ func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
|||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := rule.ID()
|
ruleKey := nbid.RuleID(rule.ID())
|
||||||
|
if _, ok := m.routeRulesMap[ruleKey]; !ok {
|
||||||
|
return fmt.Errorf("route rule not found: %s", ruleKey)
|
||||||
|
}
|
||||||
|
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == string(ruleKey)
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
return fmt.Errorf("route rule not found: %s", ruleID)
|
return fmt.Errorf("route rule not found in slice: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
m.routeRules = slices.Delete(m.routeRules, idx, idx+1)
|
||||||
|
delete(m.routeRulesMap, ruleKey)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -570,6 +586,40 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// resetState clears all firewall rules and closes connection trackers.
|
||||||
|
// Must be called with m.mutex held.
|
||||||
|
func (m *Manager) resetState() {
|
||||||
|
maps.Clear(m.outgoingRules)
|
||||||
|
maps.Clear(m.incomingDenyRules)
|
||||||
|
maps.Clear(m.incomingRules)
|
||||||
|
maps.Clear(m.routeRulesMap)
|
||||||
|
m.routeRules = m.routeRules[:0]
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
// SetupEBPFProxyNoTrack creates notrack rules for eBPF proxy loopback traffic.
|
||||||
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
func (m *Manager) SetupEBPFProxyNoTrack(proxyPort, wgPort uint16) error {
|
||||||
if m.nativeFirewall == nil {
|
if m.nativeFirewall == nil {
|
||||||
|
|||||||
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
376
client/firewall/uspfilter/filter_routeacl_test.go
Normal file
@@ -0,0 +1,376 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestAddRouteFilteringReturnsExistingRule verifies that adding the same route
|
||||||
|
// filtering rule twice returns the same rule ID (idempotent behavior).
|
||||||
|
func TestAddRouteFilteringReturnsExistingRule(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("100.64.1.0/24"),
|
||||||
|
netip.MustParsePrefix("100.64.2.0/24"),
|
||||||
|
}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add rule first time
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
// Add the same rule again
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
// These should be the same (idempotent) like nftables/iptables implementations
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Adding the same rule twice should return the same rule ID (idempotent)")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestAddRouteFilteringDifferentRulesGetDifferentIDs verifies that rules with
|
||||||
|
// different parameters get distinct IDs.
|
||||||
|
func TestAddRouteFilteringDifferentRulesGetDifferentIDs(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
|
||||||
|
// Add first rule
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add different rule (different destination)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-2"),
|
||||||
|
sources,
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("192.168.2.0/24")}, // Different!
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
assert.NotEqual(t, rule1.ID(), rule2.ID(),
|
||||||
|
"Different rules should have different IDs")
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 3, ruleCount, "Should have 3 rules (2 user rules + 1 block rule)")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleUpdateDoesNotCauseGap verifies that re-adding the same route
|
||||||
|
// rule during a network map update does not disrupt existing traffic.
|
||||||
|
func TestRouteRuleUpdateDoesNotCauseGap(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
require.True(t, pass, "Traffic should pass with rule in place")
|
||||||
|
|
||||||
|
// Re-add same rule (simulates network map update)
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Idempotent IDs mean rule1.ID() == rule2.ID(), so the ACL manager
|
||||||
|
// won't delete rule1 during cleanup. If IDs differed, deleting rule1
|
||||||
|
// would remove the only matching rule and cause a traffic gap.
|
||||||
|
if rule1.ID() != rule2.ID() {
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, passAfter := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.True(t, passAfter,
|
||||||
|
"Traffic should still pass after rule update - no gap should occur")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockInvalidRoutedIdempotent verifies that blockInvalidRouted creates
|
||||||
|
// exactly one drop rule for the WireGuard network prefix, and calling it again
|
||||||
|
// returns the same rule without duplicating.
|
||||||
|
func TestBlockInvalidRoutedIdempotent(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call blockInvalidRouted directly multiple times
|
||||||
|
rule1, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule1)
|
||||||
|
|
||||||
|
rule2, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule2)
|
||||||
|
|
||||||
|
rule3, err := manager.blockInvalidRouted(ifaceMock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule3)
|
||||||
|
|
||||||
|
// All should return the same rule
|
||||||
|
assert.Equal(t, rule1.ID(), rule2.ID(), "Second call should return same rule")
|
||||||
|
assert.Equal(t, rule2.ID(), rule3.ID(), "Third call should return same rule")
|
||||||
|
|
||||||
|
// Should have exactly 1 route rule
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount, "Should have exactly 1 block rule after 3 calls")
|
||||||
|
|
||||||
|
// Verify the rule blocks traffic to the WG network
|
||||||
|
srcIP := netip.MustParseAddr("10.0.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.50")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 80)
|
||||||
|
assert.False(t, pass, "Block rule should deny traffic to WG prefix")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestBlockRuleNotAccumulatedOnRepeatedEnableRouting verifies that calling
|
||||||
|
// EnableRouting multiple times (as happens on each route update) does not
|
||||||
|
// accumulate duplicate block rules in the routeRules slice.
|
||||||
|
func TestBlockRuleNotAccumulatedOnRepeatedEnableRouting(t *testing.T) {
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
// Call EnableRouting multiple times (simulating repeated route updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 1, ruleCount,
|
||||||
|
"Repeated EnableRouting should not accumulate block rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRouteRuleCountStableAcrossUpdates verifies that adding the same route
|
||||||
|
// rule multiple times does not create duplicate entries.
|
||||||
|
func TestRouteRuleCountStableAcrossUpdates(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Simulate 5 network map updates with the same route rule
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.mutex.RLock()
|
||||||
|
ruleCount := len(manager.routeRules)
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 2, ruleCount,
|
||||||
|
"Should have exactly 2 rules (1 user rule + 1 block rule) after 5 updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDeleteRouteRuleAfterIdempotentAdd verifies that deleting a route rule
|
||||||
|
// after adding it multiple times works correctly.
|
||||||
|
func TestDeleteRouteRuleAfterIdempotentAdd(t *testing.T) {
|
||||||
|
manager := setupTestManager(t)
|
||||||
|
|
||||||
|
sources := []netip.Prefix{netip.MustParsePrefix("100.64.1.0/24")}
|
||||||
|
destination := fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")}
|
||||||
|
|
||||||
|
// Add same rule twice
|
||||||
|
rule1, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := manager.AddRouteFiltering(
|
||||||
|
[]byte("policy-1"),
|
||||||
|
sources,
|
||||||
|
destination,
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
require.Equal(t, rule1.ID(), rule2.ID(), "Should return same rule ID")
|
||||||
|
|
||||||
|
// Delete using first reference
|
||||||
|
err = manager.DeleteRouteRule(rule1)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify traffic no longer passes
|
||||||
|
srcIP := netip.MustParseAddr("100.64.1.5")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.10")
|
||||||
|
_, pass := manager.routeACLsPass(srcIP, dstIP, layers.LayerTypeTCP, 12345, 443)
|
||||||
|
assert.False(t, pass, "Traffic should not pass after rule deletion")
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupTestManager(t *testing.T) *Manager {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
|
wgNet := netip.MustParsePrefix("100.64.0.1/16")
|
||||||
|
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: wgNet.Addr(),
|
||||||
|
Network: wgNet,
|
||||||
|
}
|
||||||
|
},
|
||||||
|
GetDeviceFunc: func() *device.FilteredDevice {
|
||||||
|
return &device.FilteredDevice{Device: dev}
|
||||||
|
},
|
||||||
|
GetWGDeviceFunc: func() *wgdevice.Device {
|
||||||
|
return &wgdevice.Device{}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, manager.EnableRouting())
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
return manager
|
||||||
|
}
|
||||||
@@ -263,6 +263,158 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleLifecycleDenyRules verifies that deny rules are correctly added
|
||||||
|
// to the deny map and can be cleanly deleted without leaving orphans.
|
||||||
|
func TestPeerRuleLifecycleDenyRules(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Add multiple deny rules for different ports
|
||||||
|
rule1, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
rule2, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 2, denyCount, "Should have exactly 2 deny rules")
|
||||||
|
|
||||||
|
// Delete the first deny rule
|
||||||
|
err = m.DeletePeerRule(rule1[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount = len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule after deleting first")
|
||||||
|
|
||||||
|
// Delete the second deny rule
|
||||||
|
err = m.DeletePeerRule(rule2[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, exists := m.incomingDenyRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
require.False(t, exists, "Deny rules IP entry should be cleaned up when empty")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestPeerRuleAddAndDeleteDontLeak verifies that repeatedly adding and deleting
|
||||||
|
// peer rules (simulating network map updates) does not leak rules in the maps.
|
||||||
|
func TestPeerRuleAddAndDeleteDontLeak(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
|
||||||
|
// Simulate 10 network map updates: add rule, delete old, add new
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
// Add a deny rule
|
||||||
|
rules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add an allow rule
|
||||||
|
allowRules, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Delete them (simulating ACL manager cleanup)
|
||||||
|
for _, r := range rules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
for _, r := range allowRules {
|
||||||
|
require.NoError(t, m.DeletePeerRule(r))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 0, denyCount, "No deny rules should remain after cleanup")
|
||||||
|
require.Equal(t, 0, allowCount, "No allow rules should remain after cleanup")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMixedAllowDenyRulesSameIP verifies that allow and deny rules for the same
|
||||||
|
// IP are stored in separate maps and don't interfere with each other.
|
||||||
|
func TestMixedAllowDenyRulesSameIP(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger, nbiface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, m.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
ip := net.ParseIP("192.168.1.1")
|
||||||
|
|
||||||
|
// Add allow rule for port 80
|
||||||
|
allowRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Add deny rule for port 22
|
||||||
|
denyRule, err := m.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil,
|
||||||
|
&fw.Port{Values: []uint16{22}}, fw.ActionDrop, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
addr := netip.MustParseAddr("192.168.1.1")
|
||||||
|
m.mutex.RLock()
|
||||||
|
allowCount := len(m.incomingRules[addr])
|
||||||
|
denyCount := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, allowCount, "Should have 1 allow rule")
|
||||||
|
require.Equal(t, 1, denyCount, "Should have 1 deny rule")
|
||||||
|
|
||||||
|
// Delete allow rule should not affect deny rule
|
||||||
|
err = m.DeletePeerRule(allowRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
denyCountAfter := len(m.incomingDenyRules[addr])
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.Equal(t, 1, denyCountAfter, "Deny rule should still exist after deleting allow rule")
|
||||||
|
|
||||||
|
// Delete deny rule
|
||||||
|
err = m.DeletePeerRule(denyRule[0])
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
m.mutex.RLock()
|
||||||
|
_, denyExists := m.incomingDenyRules[addr]
|
||||||
|
_, allowExists := m.incomingRules[addr]
|
||||||
|
m.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.False(t, denyExists, "Deny rules should be empty")
|
||||||
|
require.False(t, allowExists, "Allow rules should be empty")
|
||||||
|
}
|
||||||
|
|
||||||
func TestManagerReset(t *testing.T) {
|
func TestManagerReset(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
@@ -16,9 +18,18 @@ const (
|
|||||||
maxBatchSize = 1024 * 16
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2
|
maxMessageSize = 1024 * 2
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
logChannelSize = 1000
|
defaultLogChanSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func getLogChannelSize() int {
|
||||||
|
if v := os.Getenv("NB_USPFILTER_LOG_BUFFER"); v != "" {
|
||||||
|
if n, err := strconv.Atoi(v); err == nil && n > 0 {
|
||||||
|
return n
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return defaultLogChanSize
|
||||||
|
}
|
||||||
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -69,7 +80,7 @@ type Logger struct {
|
|||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
msgChannel: make(chan logMessage, logChannelSize),
|
msgChannel: make(chan logMessage, getLogChannelSize()),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
|
|||||||
@@ -358,9 +358,9 @@ func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
|||||||
// Fast path for IPv4 addresses (4 bytes) - most common case
|
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||||
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||||
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4])) //nolint:gosec // length checked above
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||||
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4])) //nolint:gosec // length checked above
|
||||||
} else {
|
} else {
|
||||||
// Fallback for other lengths
|
// Fallback for other lengths
|
||||||
for i := 0; i < len(oldBytes)-1; i += 2 {
|
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||||
|
|||||||
@@ -28,7 +28,7 @@ func Backoff(ctx context.Context) backoff.BackOff {
|
|||||||
|
|
||||||
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
// CreateConnection creates a gRPC client connection with the appropriate transport options.
|
||||||
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
// The component parameter specifies the WebSocket proxy component path (e.g., "/management", "/signal").
|
||||||
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string) (*grpc.ClientConn, error) {
|
func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, component string, extraOpts ...grpc.DialOption) (*grpc.ClientConn, error) {
|
||||||
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
transportOption := grpc.WithTransportCredentials(insecure.NewCredentials())
|
||||||
// for js, the outer websocket layer takes care of tls
|
// for js, the outer websocket layer takes care of tls
|
||||||
if tlsEnabled && runtime.GOOS != "js" {
|
if tlsEnabled && runtime.GOOS != "js" {
|
||||||
@@ -46,9 +46,7 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
connCtx, cancel := context.WithTimeout(ctx, 30*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
conn, err := grpc.DialContext(
|
opts := []grpc.DialOption{
|
||||||
connCtx,
|
|
||||||
addr,
|
|
||||||
transportOption,
|
transportOption,
|
||||||
WithCustomDialer(tlsEnabled, component),
|
WithCustomDialer(tlsEnabled, component),
|
||||||
grpc.WithBlock(),
|
grpc.WithBlock(),
|
||||||
@@ -56,7 +54,10 @@ func CreateConnection(ctx context.Context, addr string, tlsEnabled bool, compone
|
|||||||
Time: 30 * time.Second,
|
Time: 30 * time.Second,
|
||||||
Timeout: 10 * time.Second,
|
Timeout: 10 * time.Second,
|
||||||
}),
|
}),
|
||||||
)
|
}
|
||||||
|
opts = append(opts, extraOpts...)
|
||||||
|
|
||||||
|
conn, err := grpc.DialContext(connCtx, addr, opts...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("dial context: %w", err)
|
return nil, fmt.Errorf("dial context: %w", err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,20 +5,18 @@ package configurer
|
|||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"golang.zx2c4.com/wireguard/ipc"
|
"golang.zx2c4.com/wireguard/ipc"
|
||||||
)
|
)
|
||||||
|
|
||||||
func openUAPI(deviceName string) (net.Listener, error) {
|
func openUAPI(deviceName string) (net.Listener, error) {
|
||||||
uapiSock, err := ipc.UAPIOpen(deviceName)
|
uapiSock, err := ipc.UAPIOpen(deviceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to open uapi socket: %v", err)
|
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
listener, err := ipc.UAPIListen(deviceName, uapiSock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to listen on uapi socket: %v", err)
|
_ = uapiSock.Close()
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -54,6 +54,14 @@ func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder
|
|||||||
return wgCfg
|
return wgCfg
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func NewUSPConfigurerNoUAPI(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||||
|
return &WGUSPConfigurer{
|
||||||
|
device: device,
|
||||||
|
deviceName: deviceName,
|
||||||
|
activityRecorder: activityRecorder,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error {
|
||||||
log.Debugf("adding Wireguard private key")
|
log.Debugf("adding Wireguard private key")
|
||||||
key, err := wgtypes.ParseKey(privateKey)
|
key, err := wgtypes.ParseKey(privateKey)
|
||||||
|
|||||||
@@ -29,8 +29,9 @@ type PacketFilter interface {
|
|||||||
type FilteredDevice struct {
|
type FilteredDevice struct {
|
||||||
tun.Device
|
tun.Device
|
||||||
|
|
||||||
filter PacketFilter
|
filter PacketFilter
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
|
closeOnce sync.Once
|
||||||
}
|
}
|
||||||
|
|
||||||
// newDeviceFilter constructor function
|
// newDeviceFilter constructor function
|
||||||
@@ -40,6 +41,20 @@ func newDeviceFilter(device tun.Device) *FilteredDevice {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Close closes the underlying tun device exactly once.
|
||||||
|
// wireguard-go's netTun.Close() panics on double-close due to a bare close(channel),
|
||||||
|
// and multiple code paths can trigger Close on the same device.
|
||||||
|
func (d *FilteredDevice) Close() error {
|
||||||
|
var err error
|
||||||
|
d.closeOnce.Do(func() {
|
||||||
|
err = d.Device.Close()
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Read wraps read method with filtering feature
|
// Read wraps read method with filtering feature
|
||||||
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, err error) {
|
||||||
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
if n, err = d.Device.Read(bufs, sizes, offset); err != nil {
|
||||||
|
|||||||
@@ -79,10 +79,12 @@ func (t *TunNetstackDevice) create() (WGConfigurer, error) {
|
|||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.bind.ActivityRecorder())
|
t.configurer = configurer.NewUSPConfigurerNoUAPI(t.device, t.name, t.bind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
if cErr := tunIface.Close(); cErr != nil {
|
||||||
|
log.Debugf("failed to close tun device: %v", cErr)
|
||||||
|
}
|
||||||
return nil, fmt.Errorf("error configuring interface: %s", err)
|
return nil, fmt.Errorf("error configuring interface: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/errors"
|
"github.com/netbirdio/netbird/client/errors"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
"github.com/netbirdio/netbird/client/iface/wgproxy"
|
||||||
@@ -228,6 +229,10 @@ func (w *WGIface) Close() error {
|
|||||||
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
result = multierror.Append(result, fmt.Errorf("failed to close wireguard interface %s: %w", w.Name(), err))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if nbnetstack.IsEnabled() {
|
||||||
|
return errors.FormatErrorOrNil(result)
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.waitUntilRemoved(); err != nil {
|
if err := w.waitUntilRemoved(); err != nil {
|
||||||
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
log.Warnf("failed to remove WireGuard interface %s: %v", w.Name(), err)
|
||||||
if err := w.Destroy(); err != nil {
|
if err := w.Destroy(); err != nil {
|
||||||
|
|||||||
@@ -66,7 +66,7 @@ func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
return nsTunDev, tunNet, nil
|
return t.tundev, tunNet, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Close() error {
|
func (t *NetStackTun) Close() error {
|
||||||
|
|||||||
@@ -189,6 +189,212 @@ func TestDefaultManagerStateless(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesNotAccumulatedOnRepeatedApply verifies that applying the same
|
||||||
|
// deny rules repeatedly does not accumulate duplicate rules in the uspfilter.
|
||||||
|
// This tests the full ACL manager -> uspfilter integration.
|
||||||
|
func TestDenyRulesNotAccumulatedOnRepeatedApply(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// Apply the same rules 5 times (simulating repeated network map updates)
|
||||||
|
for i := 0; i < 5; i++ {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
}
|
||||||
|
|
||||||
|
// The ACL manager should track exactly 3 rule pairs (2 deny + 1 accept inbound)
|
||||||
|
assert.Equal(t, 3, len(acl.peerRulesPairs),
|
||||||
|
"Should have exactly 3 rule pairs after 5 identical updates")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDenyRulesCleanedUpOnRemoval verifies that deny rules are properly cleaned
|
||||||
|
// up when they're removed from the network map in a subsequent update.
|
||||||
|
func TestDenyRulesCleanedUpOnRemoval(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: add deny and accept rules
|
||||||
|
networkMap1 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap1, false)
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs), "Should have 2 rules after first update")
|
||||||
|
|
||||||
|
// Second update: remove the deny rule, keep only accept
|
||||||
|
networkMap2 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap2, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Should have 1 rule after removing deny rule")
|
||||||
|
|
||||||
|
// Third update: remove all rules
|
||||||
|
networkMap3 := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{},
|
||||||
|
FirewallRulesIsEmpty: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
acl.ApplyFiltering(networkMap3, false)
|
||||||
|
assert.Equal(t, 0, len(acl.peerRulesPairs),
|
||||||
|
"Should have 0 rules after removing all rules")
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestRuleUpdateChangingAction verifies that when a rule's action changes from
|
||||||
|
// accept to deny (or vice versa), the old rule is properly removed and the new
|
||||||
|
// one added without leaking.
|
||||||
|
func TestRuleUpdateChangingAction(t *testing.T) {
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false, iface.DefaultMTU)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, fw.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
// First update: accept rule
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
FirewallRulesIsEmpty: false,
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
|
// Second update: change to deny (same IP/port/proto, different action)
|
||||||
|
networkMap.FirewallRules = []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "22",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// Should still have exactly 1 rule (the old accept removed, new deny added)
|
||||||
|
assert.Equal(t, 1, len(acl.peerRulesPairs),
|
||||||
|
"Changing action should result in exactly 1 rule, not 2")
|
||||||
|
}
|
||||||
|
|
||||||
func TestPortInfoEmpty(t *testing.T) {
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
|
|||||||
@@ -221,7 +221,7 @@ func (a *Auth) getPKCEFlow(client *mgm.GrpcClient) (*PKCEAuthorizationFlow, erro
|
|||||||
config := &PKCEAuthProviderConfig{
|
config := &PKCEAuthProviderConfig{
|
||||||
Audience: protoConfig.GetAudience(),
|
Audience: protoConfig.GetAudience(),
|
||||||
ClientID: protoConfig.GetClientID(),
|
ClientID: protoConfig.GetClientID(),
|
||||||
ClientSecret: protoConfig.GetClientSecret(),
|
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
AuthorizationEndpoint: protoConfig.GetAuthorizationEndpoint(),
|
||||||
Scope: protoConfig.GetScope(),
|
Scope: protoConfig.GetScope(),
|
||||||
@@ -266,7 +266,7 @@ func (a *Auth) getDeviceFlow(client *mgm.GrpcClient) (*DeviceAuthorizationFlow,
|
|||||||
config := &DeviceAuthProviderConfig{
|
config := &DeviceAuthProviderConfig{
|
||||||
Audience: protoConfig.GetAudience(),
|
Audience: protoConfig.GetAudience(),
|
||||||
ClientID: protoConfig.GetClientID(),
|
ClientID: protoConfig.GetClientID(),
|
||||||
ClientSecret: protoConfig.GetClientSecret(),
|
ClientSecret: protoConfig.GetClientSecret(), //nolint:staticcheck
|
||||||
Domain: protoConfig.Domain,
|
Domain: protoConfig.Domain,
|
||||||
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
TokenEndpoint: protoConfig.GetTokenEndpoint(),
|
||||||
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
DeviceAuthEndpoint: protoConfig.GetDeviceAuthEndpoint(),
|
||||||
|
|||||||
@@ -20,14 +20,16 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
"github.com/netbirdio/netbird/client/internal/updater"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbnet "github.com/netbirdio/netbird/client/net"
|
nbnet "github.com/netbirdio/netbird/client/net"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
@@ -42,14 +44,19 @@ import (
|
|||||||
"github.com/netbirdio/netbird/version"
|
"github.com/netbirdio/netbird/version"
|
||||||
)
|
)
|
||||||
|
|
||||||
type ConnectClient struct {
|
// androidRunOverride is set on Android to inject mobile dependencies
|
||||||
ctx context.Context
|
// when using embed.Client (which calls Run() with empty MobileDependency).
|
||||||
config *profilemanager.Config
|
var androidRunOverride func(c *ConnectClient, runningChan chan struct{}, logPath string) error
|
||||||
statusRecorder *peer.Status
|
|
||||||
doInitialAutoUpdate bool
|
|
||||||
|
|
||||||
engine *Engine
|
type ConnectClient struct {
|
||||||
engineMutex sync.Mutex
|
ctx context.Context
|
||||||
|
config *profilemanager.Config
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
|
engine *Engine
|
||||||
|
engineMutex sync.Mutex
|
||||||
|
clientMetrics *metrics.ClientMetrics
|
||||||
|
updateManager *updater.Manager
|
||||||
|
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
}
|
}
|
||||||
@@ -58,19 +65,24 @@ func NewConnectClient(
|
|||||||
ctx context.Context,
|
ctx context.Context,
|
||||||
config *profilemanager.Config,
|
config *profilemanager.Config,
|
||||||
statusRecorder *peer.Status,
|
statusRecorder *peer.Status,
|
||||||
doInitalAutoUpdate bool,
|
|
||||||
) *ConnectClient {
|
) *ConnectClient {
|
||||||
return &ConnectClient{
|
return &ConnectClient{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
config: config,
|
config: config,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
doInitialAutoUpdate: doInitalAutoUpdate,
|
engineMutex: sync.Mutex{},
|
||||||
engineMutex: sync.Mutex{},
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *ConnectClient) SetUpdateManager(um *updater.Manager) {
|
||||||
|
c.updateManager = um
|
||||||
|
}
|
||||||
|
|
||||||
// Run with main logic.
|
// Run with main logic.
|
||||||
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
func (c *ConnectClient) Run(runningChan chan struct{}, logPath string) error {
|
||||||
|
if androidRunOverride != nil {
|
||||||
|
return androidRunOverride(c, runningChan, logPath)
|
||||||
|
}
|
||||||
return c.run(MobileDependency{}, runningChan, logPath)
|
return c.run(MobileDependency{}, runningChan, logPath)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -130,10 +142,34 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// Stop metrics push on exit
|
||||||
|
defer func() {
|
||||||
|
if c.clientMetrics != nil {
|
||||||
|
c.clientMetrics.StopPush()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
log.Infof("starting NetBird client version %s on %s/%s", version.NetbirdVersion(), runtime.GOOS, runtime.GOARCH)
|
||||||
|
|
||||||
nbnet.Init()
|
nbnet.Init()
|
||||||
|
|
||||||
|
// Initialize metrics once at startup (always active for debug bundles)
|
||||||
|
if c.clientMetrics == nil {
|
||||||
|
agentInfo := metrics.AgentInfo{
|
||||||
|
DeploymentType: metrics.DeploymentTypeUnknown,
|
||||||
|
Version: version.NetbirdVersion(),
|
||||||
|
OS: runtime.GOOS,
|
||||||
|
Arch: runtime.GOARCH,
|
||||||
|
}
|
||||||
|
c.clientMetrics = metrics.NewClientMetrics(agentInfo)
|
||||||
|
log.Debugf("initialized client metrics")
|
||||||
|
|
||||||
|
// Start metrics push if enabled (uses daemon context, persists across engine restarts)
|
||||||
|
if metrics.IsMetricsPushEnabled() {
|
||||||
|
c.clientMetrics.StartPush(c.ctx, metrics.PushConfigFromEnv())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
backOff := &backoff.ExponentialBackOff{
|
backOff := &backoff.ExponentialBackOff{
|
||||||
InitialInterval: time.Second,
|
InitialInterval: time.Second,
|
||||||
RandomizationFactor: 1,
|
RandomizationFactor: 1,
|
||||||
@@ -186,14 +222,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
stateManager := statemanager.New(path)
|
stateManager := statemanager.New(path)
|
||||||
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
stateManager.RegisterState(&sshconfig.ShutdownState{})
|
||||||
|
|
||||||
updateManager, err := updatemanager.NewManager(c.statusRecorder, stateManager)
|
if c.updateManager != nil {
|
||||||
if err == nil {
|
c.updateManager.CheckUpdateSuccess(c.ctx)
|
||||||
updateManager.CheckUpdateSuccess(c.ctx)
|
}
|
||||||
|
|
||||||
inst := installer.New()
|
inst := installer.New()
|
||||||
if err := inst.CleanUpInstallerFiles(); err != nil {
|
if err := inst.CleanUpInstallerFiles(); err != nil {
|
||||||
log.Errorf("failed to clean up temporary installer file: %v", err)
|
log.Errorf("failed to clean up temporary installer file: %v", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
defer c.statusRecorder.ClientStop()
|
defer c.statusRecorder.ClientStop()
|
||||||
@@ -221,6 +256,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
mgmNotifier := statusRecorderToMgmConnStateNotifier(c.statusRecorder)
|
||||||
mgmClient.SetConnStateListener(mgmNotifier)
|
mgmClient.SetConnStateListener(mgmNotifier)
|
||||||
|
|
||||||
|
// Update metrics with actual deployment type after connection
|
||||||
|
deploymentType := metrics.DetermineDeploymentType(mgmClient.GetServerURL())
|
||||||
|
agentInfo := metrics.AgentInfo{
|
||||||
|
DeploymentType: deploymentType,
|
||||||
|
Version: version.NetbirdVersion(),
|
||||||
|
OS: runtime.GOOS,
|
||||||
|
Arch: runtime.GOARCH,
|
||||||
|
}
|
||||||
|
c.clientMetrics.UpdateAgentInfo(agentInfo, myPrivateKey.PublicKey().String())
|
||||||
|
|
||||||
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
log.Debugf("connected to the Management service %s", c.config.ManagementURL.Host)
|
||||||
defer func() {
|
defer func() {
|
||||||
if err = mgmClient.Close(); err != nil {
|
if err = mgmClient.Close(); err != nil {
|
||||||
@@ -229,8 +274,10 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
// connect (just a connection, no stream yet) and login to Management Service to get an initial global Netbird config
|
||||||
|
loginStarted := time.Now()
|
||||||
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
loginResp, err := loginToManagement(engineCtx, mgmClient, publicSSHKey, c.config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), false)
|
||||||
log.Debug(err)
|
log.Debug(err)
|
||||||
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
if s, ok := gstatus.FromError(err); ok && (s.Code() == codes.PermissionDenied) {
|
||||||
state.Set(StatusNeedsLogin)
|
state.Set(StatusNeedsLogin)
|
||||||
@@ -239,12 +286,13 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
}
|
}
|
||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
c.clientMetrics.RecordLoginDuration(engineCtx, time.Since(loginStarted), true)
|
||||||
c.statusRecorder.MarkManagementConnected()
|
c.statusRecorder.MarkManagementConnected()
|
||||||
|
|
||||||
localPeerState := peer.LocalPeerState{
|
localPeerState := peer.LocalPeerState{
|
||||||
IP: loginResp.GetPeerConfig().GetAddress(),
|
IP: loginResp.GetPeerConfig().GetAddress(),
|
||||||
PubKey: myPrivateKey.PublicKey().String(),
|
PubKey: myPrivateKey.PublicKey().String(),
|
||||||
KernelInterface: device.WireGuardModuleIsLoaded(),
|
KernelInterface: device.WireGuardModuleIsLoaded() && !netstack.IsEnabled(),
|
||||||
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
FQDN: loginResp.GetPeerConfig().GetFqdn(),
|
||||||
}
|
}
|
||||||
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
c.statusRecorder.UpdateLocalPeerState(localPeerState)
|
||||||
@@ -307,7 +355,16 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
checks := loginResp.GetChecks()
|
checks := loginResp.GetChecks()
|
||||||
|
|
||||||
c.engineMutex.Lock()
|
c.engineMutex.Lock()
|
||||||
engine := NewEngine(engineCtx, cancel, signalClient, mgmClient, relayManager, engineConfig, mobileDependency, c.statusRecorder, checks, stateManager)
|
engine := NewEngine(engineCtx, cancel, engineConfig, EngineServices{
|
||||||
|
SignalClient: signalClient,
|
||||||
|
MgmClient: mgmClient,
|
||||||
|
RelayManager: relayManager,
|
||||||
|
StatusRecorder: c.statusRecorder,
|
||||||
|
Checks: checks,
|
||||||
|
StateManager: stateManager,
|
||||||
|
UpdateManager: c.updateManager,
|
||||||
|
ClientMetrics: c.clientMetrics,
|
||||||
|
}, mobileDependency)
|
||||||
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
engine.SetSyncResponsePersistence(c.persistSyncResponse)
|
||||||
c.engine = engine
|
c.engine = engine
|
||||||
c.engineMutex.Unlock()
|
c.engineMutex.Unlock()
|
||||||
@@ -317,21 +374,15 @@ func (c *ConnectClient) run(mobileDependency MobileDependency, runningChan chan
|
|||||||
return wrapErr(err)
|
return wrapErr(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if loginResp.PeerConfig != nil && loginResp.PeerConfig.AutoUpdate != nil {
|
|
||||||
// AutoUpdate will be true when the user click on "Connect" menu on the UI
|
|
||||||
if c.doInitialAutoUpdate {
|
|
||||||
log.Infof("start engine by ui, run auto-update check")
|
|
||||||
c.engine.InitialUpdateHandling(loginResp.PeerConfig.AutoUpdate)
|
|
||||||
c.doInitialAutoUpdate = false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
log.Infof("Netbird engine started, the IP is: %s", peerConfig.GetAddress())
|
||||||
state.Set(StatusConnected)
|
state.Set(StatusConnected)
|
||||||
|
|
||||||
if runningChan != nil {
|
if runningChan != nil {
|
||||||
close(runningChan)
|
select {
|
||||||
runningChan = nil
|
case <-runningChan:
|
||||||
|
default:
|
||||||
|
close(runningChan)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
<-engineCtx.Done()
|
<-engineCtx.Done()
|
||||||
|
|||||||
73
client/internal/connect_android_default.go
Normal file
73
client/internal/connect_android_default.go
Normal file
@@ -0,0 +1,73 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// noopIFaceDiscover is a stub ExternalIFaceDiscover for embed.Client on Android.
|
||||||
|
// It returns an empty interface list, which means ICE P2P candidates won't be
|
||||||
|
// discovered — connections will fall back to relay. Applications that need P2P
|
||||||
|
// should provide a real implementation via runOnAndroidEmbed that uses
|
||||||
|
// Android's ConnectivityManager to enumerate network interfaces.
|
||||||
|
type noopIFaceDiscover struct{}
|
||||||
|
|
||||||
|
func (noopIFaceDiscover) IFaces() (string, error) {
|
||||||
|
// Return empty JSON array — no local interfaces advertised for ICE.
|
||||||
|
// This is intentional: without Android's ConnectivityManager, we cannot
|
||||||
|
// reliably enumerate interfaces (netlink is restricted on Android 11+).
|
||||||
|
// Relay connections still work; only P2P hole-punching is disabled.
|
||||||
|
return "[]", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// noopNetworkChangeListener is a stub for embed.Client on Android.
|
||||||
|
// Network change events are ignored since the embed client manages its own
|
||||||
|
// reconnection logic via the engine's built-in retry mechanism.
|
||||||
|
type noopNetworkChangeListener struct{}
|
||||||
|
|
||||||
|
func (noopNetworkChangeListener) OnNetworkChanged(string) {
|
||||||
|
// No-op: embed.Client relies on the engine's internal reconnection
|
||||||
|
// logic rather than OS-level network change notifications.
|
||||||
|
}
|
||||||
|
|
||||||
|
func (noopNetworkChangeListener) SetInterfaceIP(string) {
|
||||||
|
// No-op: in netstack mode, the overlay IP is managed by the userspace
|
||||||
|
// network stack, not by OS-level interface configuration.
|
||||||
|
}
|
||||||
|
|
||||||
|
// noopDnsReadyListener is a stub for embed.Client on Android.
|
||||||
|
// DNS readiness notifications are not needed in netstack/embed mode
|
||||||
|
// since system DNS is disabled and DNS resolution happens externally.
|
||||||
|
type noopDnsReadyListener struct{}
|
||||||
|
|
||||||
|
func (noopDnsReadyListener) OnReady() {
|
||||||
|
// No-op: embed.Client does not need DNS readiness notifications.
|
||||||
|
// System DNS is disabled in netstack mode.
|
||||||
|
}
|
||||||
|
|
||||||
|
var _ stdnet.ExternalIFaceDiscover = noopIFaceDiscover{}
|
||||||
|
var _ listener.NetworkChangeListener = noopNetworkChangeListener{}
|
||||||
|
var _ dns.ReadyListener = noopDnsReadyListener{}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Wire up the default override so embed.Client.Start() works on Android
|
||||||
|
// with netstack mode. Provides complete no-op stubs for all mobile
|
||||||
|
// dependencies so the engine's existing Android code paths work unchanged.
|
||||||
|
// Applications that need P2P ICE or real DNS should replace this by
|
||||||
|
// setting androidRunOverride before calling Start().
|
||||||
|
androidRunOverride = func(c *ConnectClient, runningChan chan struct{}, logPath string) error {
|
||||||
|
return c.runOnAndroidEmbed(
|
||||||
|
noopIFaceDiscover{},
|
||||||
|
noopNetworkChangeListener{},
|
||||||
|
[]netip.AddrPort{},
|
||||||
|
noopDnsReadyListener{},
|
||||||
|
runningChan,
|
||||||
|
logPath,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
32
client/internal/connect_android_embed.go
Normal file
32
client/internal/connect_android_embed.go
Normal file
@@ -0,0 +1,32 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
|
)
|
||||||
|
|
||||||
|
// runOnAndroidEmbed is like RunOnAndroid but accepts a runningChan
|
||||||
|
// so embed.Client.Start() can detect when the engine is ready.
|
||||||
|
// It provides complete MobileDependency so the engine's existing
|
||||||
|
// Android code paths work unchanged.
|
||||||
|
func (c *ConnectClient) runOnAndroidEmbed(
|
||||||
|
iFaceDiscover stdnet.ExternalIFaceDiscover,
|
||||||
|
networkChangeListener listener.NetworkChangeListener,
|
||||||
|
dnsAddresses []netip.AddrPort,
|
||||||
|
dnsReadyListener dns.ReadyListener,
|
||||||
|
runningChan chan struct{},
|
||||||
|
logPath string,
|
||||||
|
) error {
|
||||||
|
mobileDependency := MobileDependency{
|
||||||
|
IFaceDiscover: iFaceDiscover,
|
||||||
|
NetworkChangeListener: networkChangeListener,
|
||||||
|
HostDNSAddresses: dnsAddresses,
|
||||||
|
DnsReadyListener: dnsReadyListener,
|
||||||
|
}
|
||||||
|
return c.run(mobileDependency, runningChan, logPath)
|
||||||
|
}
|
||||||
60
client/internal/daemonaddr/resolve.go
Normal file
60
client/internal/daemonaddr/resolve.go
Normal file
@@ -0,0 +1,60 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
var scanDir = "/var/run/netbird"
|
||||||
|
|
||||||
|
// setScanDir overrides the scan directory (used by tests).
|
||||||
|
func setScanDir(dir string) {
|
||||||
|
scanDir = dir
|
||||||
|
}
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr checks whether the default Unix socket exists and, if not,
|
||||||
|
// scans /var/run/netbird/ for a single .sock file to use instead. This handles the
|
||||||
|
// mismatch between the netbird@.service template (which places the socket under
|
||||||
|
// /var/run/netbird/<instance>.sock) and the CLI default (/var/run/netbird.sock).
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
if !strings.HasPrefix(addr, "unix://") {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
sockPath := strings.TrimPrefix(addr, "unix://")
|
||||||
|
if _, err := os.Stat(sockPath); err == nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
entries, err := os.ReadDir(scanDir)
|
||||||
|
if err != nil {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
var found []string
|
||||||
|
for _, e := range entries {
|
||||||
|
if e.IsDir() {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if strings.HasSuffix(e.Name(), ".sock") {
|
||||||
|
found = append(found, filepath.Join(scanDir, e.Name()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
switch len(found) {
|
||||||
|
case 1:
|
||||||
|
resolved := "unix://" + found[0]
|
||||||
|
log.Debugf("Default daemon socket not found, using discovered socket: %s", resolved)
|
||||||
|
return resolved
|
||||||
|
case 0:
|
||||||
|
return addr
|
||||||
|
default:
|
||||||
|
log.Warnf("Default daemon socket not found and multiple sockets discovered in %s; pass --daemon-addr explicitly", scanDir)
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
}
|
||||||
8
client/internal/daemonaddr/resolve_stub.go
Normal file
8
client/internal/daemonaddr/resolve_stub.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build windows || ios || android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
// ResolveUnixDaemonAddr is a no-op on platforms that don't use Unix sockets.
|
||||||
|
func ResolveUnixDaemonAddr(addr string) string {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
121
client/internal/daemonaddr/resolve_test.go
Normal file
121
client/internal/daemonaddr/resolve_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
//go:build !windows && !ios && !android
|
||||||
|
|
||||||
|
package daemonaddr
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
// createSockFile creates a regular file with a .sock extension.
|
||||||
|
// ResolveUnixDaemonAddr uses os.Stat (not net.Dial), so a regular file is
|
||||||
|
// sufficient and avoids Unix socket path-length limits on macOS.
|
||||||
|
func createSockFile(t *testing.T, path string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.WriteFile(path, nil, 0o600); err != nil {
|
||||||
|
t.Fatalf("failed to create test sock file at %s: %v", path, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_DefaultExists(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
sock := filepath.Join(tmp, "netbird.sock")
|
||||||
|
createSockFile(t, sock)
|
||||||
|
|
||||||
|
addr := "unix://" + sock
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_SingleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
// Default socket does not exist
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
// Create a scan dir with one socket
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
instanceSock := filepath.Join(sd, "main.sock")
|
||||||
|
createSockFile(t, instanceSock)
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
expected := "unix://" + instanceSock
|
||||||
|
if got != expected {
|
||||||
|
t.Errorf("expected %s, got %s", expected, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_MultipleDiscovered(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
createSockFile(t, filepath.Join(sd, "main.sock"))
|
||||||
|
createSockFile(t, filepath.Join(sd, "other.sock"))
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NoSocketsFound(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
sd := filepath.Join(tmp, "netbird")
|
||||||
|
if err := os.MkdirAll(sd, 0o755); err != nil {
|
||||||
|
t.Fatal(err)
|
||||||
|
}
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(sd)
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_NonUnixAddr(t *testing.T) {
|
||||||
|
addr := "tcp://127.0.0.1:41731"
|
||||||
|
got := ResolveUnixDaemonAddr(addr)
|
||||||
|
if got != addr {
|
||||||
|
t.Errorf("expected %s, got %s", addr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestResolveUnixDaemonAddr_ScanDirMissing(t *testing.T) {
|
||||||
|
tmp := t.TempDir()
|
||||||
|
|
||||||
|
defaultAddr := "unix://" + filepath.Join(tmp, "netbird.sock")
|
||||||
|
|
||||||
|
origScanDir := scanDir
|
||||||
|
setScanDir(filepath.Join(tmp, "nonexistent"))
|
||||||
|
t.Cleanup(func() { setScanDir(origScanDir) })
|
||||||
|
|
||||||
|
got := ResolveUnixDaemonAddr(defaultAddr)
|
||||||
|
if got != defaultAddr {
|
||||||
|
t.Errorf("expected original %s, got %s", defaultAddr, got)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -27,11 +27,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/anonymize"
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
"github.com/netbirdio/netbird/client/internal/profilemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager/installer"
|
"github.com/netbirdio/netbird/client/internal/updater/installer"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
const readmeContent = `Netbird debug bundle
|
const readmeContent = `Netbird debug bundle
|
||||||
@@ -53,6 +52,7 @@ resolved_domains.txt: Anonymized resolved domain IP addresses from the status re
|
|||||||
config.txt: Anonymized configuration information of the NetBird client.
|
config.txt: Anonymized configuration information of the NetBird client.
|
||||||
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
network_map.json: Anonymized sync response containing peer configurations, routes, DNS settings, and firewall rules.
|
||||||
state.json: Anonymized client state dump containing netbird states for the active profile.
|
state.json: Anonymized client state dump containing netbird states for the active profile.
|
||||||
|
metrics.txt: Buffered client metrics in InfluxDB line protocol format. Only present when metrics collection is enabled. Peer identifiers are anonymized.
|
||||||
mutex.prof: Mutex profiling information.
|
mutex.prof: Mutex profiling information.
|
||||||
goroutine.prof: Goroutine profiling information.
|
goroutine.prof: Goroutine profiling information.
|
||||||
block.prof: Block profiling information.
|
block.prof: Block profiling information.
|
||||||
@@ -219,6 +219,11 @@ const (
|
|||||||
darwinStdoutLogPath = "/var/log/netbird.err.log"
|
darwinStdoutLogPath = "/var/log/netbird.err.log"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// MetricsExporter is an interface for exporting metrics
|
||||||
|
type MetricsExporter interface {
|
||||||
|
Export(w io.Writer) error
|
||||||
|
}
|
||||||
|
|
||||||
type BundleGenerator struct {
|
type BundleGenerator struct {
|
||||||
anonymizer *anonymize.Anonymizer
|
anonymizer *anonymize.Anonymizer
|
||||||
|
|
||||||
@@ -229,6 +234,7 @@ type BundleGenerator struct {
|
|||||||
logPath string
|
logPath string
|
||||||
cpuProfile []byte
|
cpuProfile []byte
|
||||||
refreshStatus func() // Optional callback to refresh status before bundle generation
|
refreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
|
clientMetrics MetricsExporter
|
||||||
|
|
||||||
anonymize bool
|
anonymize bool
|
||||||
includeSystemInfo bool
|
includeSystemInfo bool
|
||||||
@@ -250,6 +256,7 @@ type GeneratorDependencies struct {
|
|||||||
LogPath string
|
LogPath string
|
||||||
CPUProfile []byte
|
CPUProfile []byte
|
||||||
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
RefreshStatus func() // Optional callback to refresh status before bundle generation
|
||||||
|
ClientMetrics MetricsExporter
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGenerator {
|
||||||
@@ -268,6 +275,7 @@ func NewBundleGenerator(deps GeneratorDependencies, cfg BundleConfig) *BundleGen
|
|||||||
logPath: deps.LogPath,
|
logPath: deps.LogPath,
|
||||||
cpuProfile: deps.CPUProfile,
|
cpuProfile: deps.CPUProfile,
|
||||||
refreshStatus: deps.RefreshStatus,
|
refreshStatus: deps.RefreshStatus,
|
||||||
|
clientMetrics: deps.ClientMetrics,
|
||||||
|
|
||||||
anonymize: cfg.Anonymize,
|
anonymize: cfg.Anonymize,
|
||||||
includeSystemInfo: cfg.IncludeSystemInfo,
|
includeSystemInfo: cfg.IncludeSystemInfo,
|
||||||
@@ -351,6 +359,10 @@ func (g *BundleGenerator) createArchive() error {
|
|||||||
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
log.Errorf("failed to add corrupted state files to debug bundle: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := g.addMetrics(); err != nil {
|
||||||
|
log.Errorf("failed to add metrics to debug bundle: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := g.addWgShow(); err != nil {
|
if err := g.addWgShow(); err != nil {
|
||||||
log.Errorf("failed to add wg show output: %v", err)
|
log.Errorf("failed to add wg show output: %v", err)
|
||||||
}
|
}
|
||||||
@@ -418,7 +430,10 @@ func (g *BundleGenerator) addStatus() error {
|
|||||||
fullStatus := g.statusRecorder.GetFullStatus()
|
fullStatus := g.statusRecorder.GetFullStatus()
|
||||||
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
protoFullStatus := nbstatus.ToProtoFullStatus(fullStatus)
|
||||||
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
protoFullStatus.Events = g.statusRecorder.GetEventHistory()
|
||||||
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, g.anonymize, version.NetbirdVersion(), "", nil, nil, nil, "", profName)
|
overview := nbstatus.ConvertToStatusOutputOverview(protoFullStatus, nbstatus.ConvertOptions{
|
||||||
|
Anonymize: g.anonymize,
|
||||||
|
ProfileName: profName,
|
||||||
|
})
|
||||||
statusOutput := overview.FullDetailSummary()
|
statusOutput := overview.FullDetailSummary()
|
||||||
|
|
||||||
statusReader := strings.NewReader(statusOutput)
|
statusReader := strings.NewReader(statusOutput)
|
||||||
@@ -744,6 +759,30 @@ func (g *BundleGenerator) addCorruptedStateFiles() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addMetrics() error {
|
||||||
|
if g.clientMetrics == nil {
|
||||||
|
log.Debugf("skipping metrics in debug bundle: no metrics collector")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
if err := g.clientMetrics.Export(&buf); err != nil {
|
||||||
|
return fmt.Errorf("export metrics: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if buf.Len() == 0 {
|
||||||
|
log.Debugf("skipping metrics.txt in debug bundle: no metrics data")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := g.addFileToZip(&buf, "metrics.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add metrics file to zip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("added metrics to debug bundle")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (g *BundleGenerator) addLogfile() error {
|
func (g *BundleGenerator) addLogfile() error {
|
||||||
if g.logPath == "" {
|
if g.logPath == "" {
|
||||||
log.Debugf("skipping empty log file in debug bundle")
|
log.Debugf("skipping empty log file in debug bundle")
|
||||||
|
|||||||
@@ -14,6 +14,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
|
"github.com/hashicorp/go-multierror"
|
||||||
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/exp/maps"
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
@@ -22,6 +24,7 @@ import (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
netbirdDNSStateKeyFormat = "State:/Network/Service/NetBird-%s/DNS"
|
||||||
|
netbirdDNSStateKeyIndexedFormat = "State:/Network/Service/NetBird-%s-%d/DNS"
|
||||||
globalIPv4State = "State:/Network/Global/IPv4"
|
globalIPv4State = "State:/Network/Global/IPv4"
|
||||||
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
primaryServiceStateKeyFormat = "State:/Network/Service/%s/DNS"
|
||||||
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
keySupplementalMatchDomains = "SupplementalMatchDomains"
|
||||||
@@ -35,6 +38,14 @@ const (
|
|||||||
searchSuffix = "Search"
|
searchSuffix = "Search"
|
||||||
matchSuffix = "Match"
|
matchSuffix = "Match"
|
||||||
localSuffix = "Local"
|
localSuffix = "Local"
|
||||||
|
|
||||||
|
// maxDomainsPerResolverEntry is the max number of domains per scutil resolver key.
|
||||||
|
// scutil's d.add has maxArgs=101 (key + * + 99 values), so 99 is the hard cap.
|
||||||
|
maxDomainsPerResolverEntry = 50
|
||||||
|
|
||||||
|
// maxDomainBytesPerResolverEntry is the max total bytes of domain strings per key.
|
||||||
|
// scutil has an undocumented ~2048 byte value buffer; we stay well under it.
|
||||||
|
maxDomainBytesPerResolverEntry = 1500
|
||||||
)
|
)
|
||||||
|
|
||||||
type systemConfigurator struct {
|
type systemConfigurator struct {
|
||||||
@@ -84,28 +95,23 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
if err := s.removeKeysContaining(matchSuffix); err != nil {
|
||||||
var err error
|
log.Warnf("failed to remove old match keys: %v", err)
|
||||||
if len(matchDomains) != 0 {
|
|
||||||
err = s.addMatchDomains(matchKey, strings.Join(matchDomains, " "), config.ServerIP, config.ServerPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("removing match domains from the system")
|
|
||||||
err = s.removeKeyFromSystemConfig(matchKey)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if len(matchDomains) != 0 {
|
||||||
return fmt.Errorf("add match domains: %w", err)
|
if err := s.addBatchedDomains(matchSuffix, matchDomains, config.ServerIP, config.ServerPort, false); err != nil {
|
||||||
|
return fmt.Errorf("add match domains: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.updateState(stateManager)
|
s.updateState(stateManager)
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
if err := s.removeKeysContaining(searchSuffix); err != nil {
|
||||||
if len(searchDomains) != 0 {
|
log.Warnf("failed to remove old search keys: %v", err)
|
||||||
err = s.addSearchDomains(searchKey, strings.Join(searchDomains, " "), config.ServerIP, config.ServerPort)
|
|
||||||
} else {
|
|
||||||
log.Infof("removing search domains from the system")
|
|
||||||
err = s.removeKeyFromSystemConfig(searchKey)
|
|
||||||
}
|
}
|
||||||
if err != nil {
|
if len(searchDomains) != 0 {
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
if err := s.addBatchedDomains(searchSuffix, searchDomains, config.ServerIP, config.ServerPort, true); err != nil {
|
||||||
|
return fmt.Errorf("add search domains: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
s.updateState(stateManager)
|
s.updateState(stateManager)
|
||||||
|
|
||||||
@@ -149,8 +155,7 @@ func (s *systemConfigurator) restoreHostDNS() error {
|
|||||||
|
|
||||||
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
||||||
if len(s.createdKeys) == 0 {
|
if len(s.createdKeys) == 0 {
|
||||||
// return defaults for startup calls
|
return s.discoverExistingKeys()
|
||||||
return []string{getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix), getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
keys := make([]string, 0, len(s.createdKeys))
|
keys := make([]string, 0, len(s.createdKeys))
|
||||||
@@ -160,6 +165,47 @@ func (s *systemConfigurator) getRemovableKeysWithDefaults() []string {
|
|||||||
return keys
|
return keys
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// discoverExistingKeys probes scutil for all NetBird DNS keys that may exist.
|
||||||
|
// This handles the case where createdKeys is empty (e.g., state file lost after unclean shutdown).
|
||||||
|
func (s *systemConfigurator) discoverExistingKeys() []string {
|
||||||
|
dnsKeys, err := getSystemDNSKeys()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to get system DNS keys: %v", err)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
var keys []string
|
||||||
|
|
||||||
|
for _, suffix := range []string{searchSuffix, matchSuffix, localSuffix} {
|
||||||
|
key := getKeyWithInput(netbirdDNSStateKeyFormat, suffix)
|
||||||
|
if strings.Contains(dnsKeys, key) {
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, suffix := range []string{searchSuffix, matchSuffix} {
|
||||||
|
for i := 0; ; i++ {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||||
|
if !strings.Contains(dnsKeys, key) {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
keys = append(keys, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return keys
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemDNSKeys gets all DNS keys
|
||||||
|
func getSystemDNSKeys() (string, error) {
|
||||||
|
command := "list .*DNS\nquit\n"
|
||||||
|
out, err := runSystemConfigCommand(command)
|
||||||
|
if err != nil {
|
||||||
|
return "", err
|
||||||
|
}
|
||||||
|
return string(out), nil
|
||||||
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
func (s *systemConfigurator) removeKeyFromSystemConfig(key string) error {
|
||||||
line := buildRemoveKeyOperation(key)
|
line := buildRemoveKeyOperation(key)
|
||||||
_, err := runSystemConfigCommand(wrapCommand(line))
|
_, err := runSystemConfigCommand(wrapCommand(line))
|
||||||
@@ -184,12 +230,11 @@ func (s *systemConfigurator) addLocalDNS() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.addSearchDomains(
|
domainsStr := strings.Join(s.systemDNSSettings.Domains, " ")
|
||||||
localKey,
|
if err := s.addDNSState(localKey, domainsStr, s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort, true); err != nil {
|
||||||
strings.Join(s.systemDNSSettings.Domains, " "), s.systemDNSSettings.ServerIP, s.systemDNSSettings.ServerPort,
|
return fmt.Errorf("add local dns state: %w", err)
|
||||||
); err != nil {
|
|
||||||
return fmt.Errorf("add search domains: %w", err)
|
|
||||||
}
|
}
|
||||||
|
s.createdKeys[localKey] = struct{}{}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -280,28 +325,77 @@ func (s *systemConfigurator) getOriginalNameservers() []netip.Addr {
|
|||||||
return slices.Clone(s.origNameservers)
|
return slices.Clone(s.origNameservers)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addSearchDomains(key, domains string, ip netip.Addr, port int) error {
|
// splitDomainsIntoBatches splits domains into batches respecting both element count and byte size limits.
|
||||||
err := s.addDNSState(key, domains, ip, port, true)
|
func splitDomainsIntoBatches(domains []string) [][]string {
|
||||||
if err != nil {
|
if len(domains) == 0 {
|
||||||
return fmt.Errorf("add dns state: %w", err)
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d search domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
var batches [][]string
|
||||||
|
var current []string
|
||||||
|
currentBytes := 0
|
||||||
|
|
||||||
s.createdKeys[key] = struct{}{}
|
for _, d := range domains {
|
||||||
|
domainLen := len(d)
|
||||||
|
newBytes := currentBytes + domainLen
|
||||||
|
if currentBytes > 0 {
|
||||||
|
newBytes++ // space separator
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
if len(current) > 0 && (len(current) >= maxDomainsPerResolverEntry || newBytes > maxDomainBytesPerResolverEntry) {
|
||||||
|
batches = append(batches, current)
|
||||||
|
current = nil
|
||||||
|
currentBytes = 0
|
||||||
|
}
|
||||||
|
|
||||||
|
current = append(current, d)
|
||||||
|
if currentBytes > 0 {
|
||||||
|
currentBytes += 1 + domainLen
|
||||||
|
} else {
|
||||||
|
currentBytes = domainLen
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(current) > 0 {
|
||||||
|
batches = append(batches, current)
|
||||||
|
}
|
||||||
|
|
||||||
|
return batches
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemConfigurator) addMatchDomains(key, domains string, dnsServer netip.Addr, port int) error {
|
// removeKeysContaining removes all created keys that contain the given substring.
|
||||||
err := s.addDNSState(key, domains, dnsServer, port, false)
|
func (s *systemConfigurator) removeKeysContaining(suffix string) error {
|
||||||
if err != nil {
|
var toRemove []string
|
||||||
return fmt.Errorf("add dns state: %w", err)
|
for key := range s.createdKeys {
|
||||||
|
if strings.Contains(key, suffix) {
|
||||||
|
toRemove = append(toRemove, key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
var multiErr *multierror.Error
|
||||||
|
for _, key := range toRemove {
|
||||||
|
if err := s.removeKeyFromSystemConfig(key); err != nil {
|
||||||
|
multiErr = multierror.Append(multiErr, fmt.Errorf("couldn't remove key %s: %w", key, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(multiErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// addBatchedDomains splits domains into batches and creates indexed scutil keys for each batch.
|
||||||
|
func (s *systemConfigurator) addBatchedDomains(suffix string, domains []string, ip netip.Addr, port int, enableSearch bool) error {
|
||||||
|
batches := splitDomainsIntoBatches(domains)
|
||||||
|
|
||||||
|
for i, batch := range batches {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, suffix, i)
|
||||||
|
domainsStr := strings.Join(batch, " ")
|
||||||
|
|
||||||
|
if err := s.addDNSState(key, domainsStr, ip, port, enableSearch); err != nil {
|
||||||
|
return fmt.Errorf("add dns state for batch %d: %w", i, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.createdKeys[key] = struct{}{}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d match domains to the state. Domain list: %s", len(strings.Split(domains, " ")), domains)
|
log.Infof("added %d %s domains across %d resolver entries", len(domains), suffix, len(batches))
|
||||||
|
|
||||||
s.createdKeys[key] = struct{}{}
|
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -364,7 +458,6 @@ func (s *systemConfigurator) flushDNSCache() error {
|
|||||||
if out, err := cmd.CombinedOutput(); err != nil {
|
if out, err := cmd.CombinedOutput(); err != nil {
|
||||||
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
return fmt.Errorf("restart mDNSResponder: %w, output: %s", err, out)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("flushed DNS cache")
|
log.Info("flushed DNS cache")
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,10 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bufio"
|
||||||
|
"bytes"
|
||||||
"context"
|
"context"
|
||||||
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"path/filepath"
|
"path/filepath"
|
||||||
@@ -49,17 +52,22 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, sm.PersistState(context.Background()))
|
require.NoError(t, sm.PersistState(context.Background()))
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
||||||
|
|
||||||
|
// Collect all created keys for cleanup verification
|
||||||
|
createdKeys := make([]string, 0, len(configurator.createdKeys))
|
||||||
|
for key := range configurator.createdKeys {
|
||||||
|
createdKeys = append(createdKeys, key)
|
||||||
|
}
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
_ = removeTestDNSKey(key)
|
_ = removeTestDNSKey(key)
|
||||||
}
|
}
|
||||||
|
_ = removeTestDNSKey(localKey)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
exists, err := checkDNSKeyExists(key)
|
exists, err := checkDNSKeyExists(key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
if exists {
|
if exists {
|
||||||
@@ -83,13 +91,223 @@ func TestDarwinDNSUncleanShutdownCleanup(t *testing.T) {
|
|||||||
err = shutdownState.Cleanup()
|
err = shutdownState.Cleanup()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for _, key := range createdKeys {
|
||||||
exists, err := checkDNSKeyExists(key)
|
exists, err := checkDNSKeyExists(key)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
assert.False(t, exists, "Key %s should NOT exist after cleanup", key)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// generateShortDomains generates domains like a.com, b.com, ..., aa.com, ab.com, etc.
|
||||||
|
func generateShortDomains(count int) []string {
|
||||||
|
domains := make([]string, 0, count)
|
||||||
|
for i := range count {
|
||||||
|
label := ""
|
||||||
|
n := i
|
||||||
|
for {
|
||||||
|
label = string(rune('a'+n%26)) + label
|
||||||
|
n = n/26 - 1
|
||||||
|
if n < 0 {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
domains = append(domains, label+".com")
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateLongDomains generates domains like subdomain-000.department.organization-name.example.com
|
||||||
|
func generateLongDomains(count int) []string {
|
||||||
|
domains := make([]string, 0, count)
|
||||||
|
for i := range count {
|
||||||
|
domains = append(domains, fmt.Sprintf("subdomain-%03d.department.organization-name.example.com", i))
|
||||||
|
}
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
// readDomainsFromKey reads the SupplementalMatchDomains array back from scutil for a given key.
|
||||||
|
func readDomainsFromKey(t *testing.T, key string) []string {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
cmd := exec.Command(scutilPath)
|
||||||
|
cmd.Stdin = strings.NewReader(fmt.Sprintf("open\nshow %s\nquit\n", key))
|
||||||
|
out, err := cmd.Output()
|
||||||
|
require.NoError(t, err, "scutil show should succeed")
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
inArray := false
|
||||||
|
scanner := bufio.NewScanner(bytes.NewReader(out))
|
||||||
|
for scanner.Scan() {
|
||||||
|
line := strings.TrimSpace(scanner.Text())
|
||||||
|
if strings.HasPrefix(line, "SupplementalMatchDomains") && strings.Contains(line, "<array>") {
|
||||||
|
inArray = true
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if inArray {
|
||||||
|
if line == "}" {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
// lines look like: "0 : a.com"
|
||||||
|
parts := strings.SplitN(line, " : ", 2)
|
||||||
|
if len(parts) == 2 {
|
||||||
|
domains = append(domains, parts[1])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
require.NoError(t, scanner.Err())
|
||||||
|
return domains
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSplitDomainsIntoBatches(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
domains []string
|
||||||
|
expectedCount int
|
||||||
|
checkAllPresent bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "empty",
|
||||||
|
domains: nil,
|
||||||
|
expectedCount: 0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "under_limit",
|
||||||
|
domains: generateShortDomains(10),
|
||||||
|
expectedCount: 1,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "at_element_limit",
|
||||||
|
domains: generateShortDomains(50),
|
||||||
|
expectedCount: 1,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "over_element_limit",
|
||||||
|
domains: generateShortDomains(51),
|
||||||
|
expectedCount: 2,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "triple_element_limit",
|
||||||
|
domains: generateShortDomains(150),
|
||||||
|
expectedCount: 3,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "long_domains_hit_byte_limit",
|
||||||
|
domains: generateLongDomains(50),
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_short_domains",
|
||||||
|
domains: generateShortDomains(500),
|
||||||
|
expectedCount: 10,
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "500_long_domains",
|
||||||
|
domains: generateLongDomains(500),
|
||||||
|
checkAllPresent: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tests {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
batches := splitDomainsIntoBatches(tc.domains)
|
||||||
|
|
||||||
|
if tc.expectedCount > 0 {
|
||||||
|
assert.Len(t, batches, tc.expectedCount, "expected %d batches", tc.expectedCount)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify each batch respects limits
|
||||||
|
for i, batch := range batches {
|
||||||
|
assert.LessOrEqual(t, len(batch), maxDomainsPerResolverEntry,
|
||||||
|
"batch %d exceeds element limit", i)
|
||||||
|
|
||||||
|
totalBytes := 0
|
||||||
|
for j, d := range batch {
|
||||||
|
if j > 0 {
|
||||||
|
totalBytes++
|
||||||
|
}
|
||||||
|
totalBytes += len(d)
|
||||||
|
}
|
||||||
|
assert.LessOrEqual(t, totalBytes, maxDomainBytesPerResolverEntry,
|
||||||
|
"batch %d exceeds byte limit (%d bytes)", i, totalBytes)
|
||||||
|
}
|
||||||
|
|
||||||
|
if tc.checkAllPresent {
|
||||||
|
var all []string
|
||||||
|
for _, batch := range batches {
|
||||||
|
all = append(all, batch...)
|
||||||
|
}
|
||||||
|
assert.Equal(t, tc.domains, all, "all domains should be present in order")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestMatchDomainBatching writes increasing numbers of domains via the batching mechanism
|
||||||
|
// and verifies all domains are readable across multiple scutil keys.
|
||||||
|
func TestMatchDomainBatching(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping scutil integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
count int
|
||||||
|
generator func(int) []string
|
||||||
|
}{
|
||||||
|
{"short_10", 10, generateShortDomains},
|
||||||
|
{"short_50", 50, generateShortDomains},
|
||||||
|
{"short_100", 100, generateShortDomains},
|
||||||
|
{"short_200", 200, generateShortDomains},
|
||||||
|
{"short_500", 500, generateShortDomains},
|
||||||
|
{"long_10", 10, generateLongDomains},
|
||||||
|
{"long_50", 50, generateLongDomains},
|
||||||
|
{"long_100", 100, generateLongDomains},
|
||||||
|
{"long_200", 200, generateLongDomains},
|
||||||
|
{"long_500", 500, generateLongDomains},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
configurator := &systemConfigurator{
|
||||||
|
createdKeys: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
for key := range configurator.createdKeys {
|
||||||
|
_ = removeTestDNSKey(key)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
domains := tc.generator(tc.count)
|
||||||
|
err := configurator.addBatchedDomains(matchSuffix, domains, netip.MustParseAddr("100.64.0.1"), 53, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
batches := splitDomainsIntoBatches(domains)
|
||||||
|
t.Logf("wrote %d domains across %d batched keys", tc.count, len(batches))
|
||||||
|
|
||||||
|
// Read back all domains from all batched keys
|
||||||
|
var got []string
|
||||||
|
for i := range batches {
|
||||||
|
key := fmt.Sprintf(netbirdDNSStateKeyIndexedFormat, matchSuffix, i)
|
||||||
|
exists, err := checkDNSKeyExists(key)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, exists, "key %s should exist", key)
|
||||||
|
|
||||||
|
got = append(got, readDomainsFromKey(t, key)...)
|
||||||
|
}
|
||||||
|
|
||||||
|
t.Logf("read back %d/%d domains from %d keys", len(got), tc.count, len(batches))
|
||||||
|
assert.Equal(t, tc.count, len(got), "all domains should be readable")
|
||||||
|
assert.Equal(t, domains, got, "domains should match in order")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func checkDNSKeyExists(key string) (bool, error) {
|
func checkDNSKeyExists(key string) (bool, error) {
|
||||||
cmd := exec.Command(scutilPath)
|
cmd := exec.Command(scutilPath)
|
||||||
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
cmd.Stdin = strings.NewReader("show " + key + "\nquit\n")
|
||||||
@@ -158,15 +376,15 @@ func setupTestConfigurator(t *testing.T) (*systemConfigurator, *statemanager.Man
|
|||||||
createdKeys: make(map[string]struct{}),
|
createdKeys: make(map[string]struct{}),
|
||||||
}
|
}
|
||||||
|
|
||||||
searchKey := getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix)
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
|
||||||
localKey := getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix)
|
|
||||||
|
|
||||||
cleanup := func() {
|
cleanup := func() {
|
||||||
_ = sm.Stop(context.Background())
|
_ = sm.Stop(context.Background())
|
||||||
for _, key := range []string{searchKey, matchKey, localKey} {
|
for key := range configurator.createdKeys {
|
||||||
_ = removeTestDNSKey(key)
|
_ = removeTestDNSKey(key)
|
||||||
}
|
}
|
||||||
|
// Also clean up old-format keys and local key in case they exist
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, searchSuffix))
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix))
|
||||||
|
_ = removeTestDNSKey(getKeyWithInput(netbirdDNSStateKeyFormat, localSuffix))
|
||||||
}
|
}
|
||||||
|
|
||||||
return configurator, sm, cleanup
|
return configurator, sm, cleanup
|
||||||
|
|||||||
@@ -42,6 +42,8 @@ const (
|
|||||||
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
dnsPolicyConfigConfigOptionsKey = "ConfigOptions"
|
||||||
dnsPolicyConfigConfigOptionsValue = 0x8
|
dnsPolicyConfigConfigOptionsValue = 0x8
|
||||||
|
|
||||||
|
nrptMaxDomainsPerRule = 50
|
||||||
|
|
||||||
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
interfaceConfigPath = `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces`
|
||||||
interfaceConfigNameServerKey = "NameServer"
|
interfaceConfigNameServerKey = "NameServer"
|
||||||
interfaceConfigSearchListKey = "SearchList"
|
interfaceConfigSearchListKey = "SearchList"
|
||||||
@@ -198,10 +200,11 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
count, err := r.addDNSMatchPolicy(matchDomains, config.ServerIP)
|
||||||
|
// Update count even on error to ensure cleanup covers partially created rules
|
||||||
|
r.nrptEntryCount = count
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("add dns match policy: %w", err)
|
return fmt.Errorf("add dns match policy: %w", err)
|
||||||
}
|
}
|
||||||
r.nrptEntryCount = count
|
|
||||||
} else {
|
} else {
|
||||||
r.nrptEntryCount = 0
|
r.nrptEntryCount = 0
|
||||||
}
|
}
|
||||||
@@ -239,23 +242,33 @@ func (r *registryConfigurator) addDNSSetupForAll(ip netip.Addr) error {
|
|||||||
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr) (int, error) {
|
||||||
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
// if the gpo key is present, we need to put our DNS settings there, otherwise our config might be ignored
|
||||||
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
// see https://learn.microsoft.com/en-us/openspecs/windows_protocols/ms-gpnrpt/8cc31cb9-20cb-4140-9e85-3e08703b4745
|
||||||
for i, domain := range domains {
|
|
||||||
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i)
|
|
||||||
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, i)
|
|
||||||
|
|
||||||
singleDomain := []string{domain}
|
// We need to batch domains into chunks and create one NRPT rule per batch.
|
||||||
|
ruleIndex := 0
|
||||||
|
for i := 0; i < len(domains); i += nrptMaxDomainsPerRule {
|
||||||
|
end := i + nrptMaxDomainsPerRule
|
||||||
|
if end > len(domains) {
|
||||||
|
end = len(domains)
|
||||||
|
}
|
||||||
|
batchDomains := domains[i:end]
|
||||||
|
|
||||||
if err := r.configureDNSPolicy(localPath, singleDomain, ip); err != nil {
|
localPath := fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, ruleIndex)
|
||||||
return i, fmt.Errorf("configure DNS Local policy for domain %s: %w", domain, err)
|
gpoPath := fmt.Sprintf("%s-%d", gpoDnsPolicyConfigMatchPath, ruleIndex)
|
||||||
|
|
||||||
|
if err := r.configureDNSPolicy(localPath, batchDomains, ip); err != nil {
|
||||||
|
return ruleIndex, fmt.Errorf("configure DNS Local policy for rule %d: %w", ruleIndex, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Increment immediately so the caller's cleanup path knows about this rule
|
||||||
|
ruleIndex++
|
||||||
|
|
||||||
if r.gpo {
|
if r.gpo {
|
||||||
if err := r.configureDNSPolicy(gpoPath, singleDomain, ip); err != nil {
|
if err := r.configureDNSPolicy(gpoPath, batchDomains, ip); err != nil {
|
||||||
return i, fmt.Errorf("configure gpo DNS policy: %w", err)
|
return ruleIndex, fmt.Errorf("configure gpo DNS policy for rule %d: %w", ruleIndex-1, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("added NRPT entry for domain: %s", domain)
|
log.Debugf("added NRPT rule %d with %d domains", ruleIndex-1, len(batchDomains))
|
||||||
}
|
}
|
||||||
|
|
||||||
if r.gpo {
|
if r.gpo {
|
||||||
@@ -264,8 +277,8 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip netip.Addr
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("added %d separate NRPT entries. Domain list: %s", len(domains), domains)
|
log.Infof("added %d NRPT rules for %d domains", ruleIndex, len(domains))
|
||||||
return len(domains), nil
|
return ruleIndex, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
func (r *registryConfigurator) configureDNSPolicy(policyPath string, domains []string, ip netip.Addr) error {
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
// TestNRPTEntriesCleanupOnConfigChange tests that old NRPT entries are properly cleaned up
|
||||||
// when the number of match domains decreases between configuration changes.
|
// when the number of match domains decreases between configuration changes.
|
||||||
|
// With batching enabled (50 domains per rule), we need enough domains to create multiple rules.
|
||||||
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
||||||
if testing.Short() {
|
if testing.Short() {
|
||||||
t.Skip("skipping registry integration test in short mode")
|
t.Skip("skipping registry integration test in short mode")
|
||||||
@@ -37,51 +38,60 @@ func TestNRPTEntriesCleanupOnConfigChange(t *testing.T) {
|
|||||||
gpo: false,
|
gpo: false,
|
||||||
}
|
}
|
||||||
|
|
||||||
config5 := HostDNSConfig{
|
// Create 125 domains which will result in 3 NRPT rules (50+50+25)
|
||||||
ServerIP: testIP,
|
domains125 := make([]DomainConfig, 125)
|
||||||
Domains: []DomainConfig{
|
for i := 0; i < 125; i++ {
|
||||||
{Domain: "domain1.com", MatchOnly: true},
|
domains125[i] = DomainConfig{
|
||||||
{Domain: "domain2.com", MatchOnly: true},
|
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||||
{Domain: "domain3.com", MatchOnly: true},
|
MatchOnly: true,
|
||||||
{Domain: "domain4.com", MatchOnly: true},
|
}
|
||||||
{Domain: "domain5.com", MatchOnly: true},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cfg.applyDNSConfig(config5, nil)
|
config125 := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: domains125,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = cfg.applyDNSConfig(config125, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify all 5 entries exist
|
// Verify 3 NRPT rules exist
|
||||||
for i := 0; i < 5; i++ {
|
assert.Equal(t, 3, cfg.nrptEntryCount, "Should create 3 NRPT rules for 125 domains")
|
||||||
|
for i := 0; i < 3; i++ {
|
||||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, exists, "Entry %d should exist after first config", i)
|
assert.True(t, exists, "NRPT rule %d should exist after first config", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
config2 := HostDNSConfig{
|
// Reduce to 75 domains which will result in 2 NRPT rules (50+25)
|
||||||
|
domains75 := make([]DomainConfig, 75)
|
||||||
|
for i := 0; i < 75; i++ {
|
||||||
|
domains75[i] = DomainConfig{
|
||||||
|
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||||
|
MatchOnly: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config75 := HostDNSConfig{
|
||||||
ServerIP: testIP,
|
ServerIP: testIP,
|
||||||
Domains: []DomainConfig{
|
Domains: domains75,
|
||||||
{Domain: "domain1.com", MatchOnly: true},
|
|
||||||
{Domain: "domain2.com", MatchOnly: true},
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
|
|
||||||
err = cfg.applyDNSConfig(config2, nil)
|
err = cfg.applyDNSConfig(config75, nil)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify first 2 entries exist
|
// Verify first 2 NRPT rules exist
|
||||||
|
assert.Equal(t, 2, cfg.nrptEntryCount, "Should create 2 NRPT rules for 75 domains")
|
||||||
for i := 0; i < 2; i++ {
|
for i := 0; i < 2; i++ {
|
||||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
assert.True(t, exists, "Entry %d should exist after second config", i)
|
assert.True(t, exists, "NRPT rule %d should exist after second config", i)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify entries 2-4 are cleaned up
|
// Verify rule 2 is cleaned up
|
||||||
for i := 2; i < 5; i++ {
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, 2))
|
||||||
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
require.NoError(t, err)
|
||||||
require.NoError(t, err)
|
assert.False(t, exists, "NRPT rule 2 should NOT exist after reducing to 75 domains")
|
||||||
assert.False(t, exists, "Entry %d should NOT exist after reducing to 2 domains", i)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func registryKeyExists(path string) (bool, error) {
|
func registryKeyExists(path string) (bool, error) {
|
||||||
@@ -97,6 +107,106 @@ func registryKeyExists(path string) (bool, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func cleanupRegistryKeys(*testing.T) {
|
func cleanupRegistryKeys(*testing.T) {
|
||||||
cfg := ®istryConfigurator{nrptEntryCount: 10}
|
// Clean up more entries to account for batching tests with many domains
|
||||||
|
cfg := ®istryConfigurator{nrptEntryCount: 20}
|
||||||
_ = cfg.removeDNSMatchPolicies()
|
_ = cfg.removeDNSMatchPolicies()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TestNRPTDomainBatching verifies that domains are correctly batched into NRPT rules.
|
||||||
|
func TestNRPTDomainBatching(t *testing.T) {
|
||||||
|
if testing.Short() {
|
||||||
|
t.Skip("skipping registry integration test in short mode")
|
||||||
|
}
|
||||||
|
|
||||||
|
defer cleanupRegistryKeys(t)
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
testIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
|
||||||
|
// Create a test interface registry key so updateSearchDomains doesn't fail
|
||||||
|
testGUID := "{12345678-1234-1234-1234-123456789ABC}"
|
||||||
|
interfacePath := `SYSTEM\CurrentControlSet\Services\Tcpip\Parameters\Interfaces\` + testGUID
|
||||||
|
testKey, _, err := registry.CreateKey(registry.LOCAL_MACHINE, interfacePath, registry.SET_VALUE)
|
||||||
|
require.NoError(t, err, "Should create test interface registry key")
|
||||||
|
testKey.Close()
|
||||||
|
defer func() {
|
||||||
|
_ = registry.DeleteKey(registry.LOCAL_MACHINE, interfacePath)
|
||||||
|
}()
|
||||||
|
|
||||||
|
cfg := ®istryConfigurator{
|
||||||
|
guid: testGUID,
|
||||||
|
gpo: false,
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
domainCount int
|
||||||
|
expectedRuleCount int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Less than 50 domains (single rule)",
|
||||||
|
domainCount: 30,
|
||||||
|
expectedRuleCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Exactly 50 domains (single rule)",
|
||||||
|
domainCount: 50,
|
||||||
|
expectedRuleCount: 1,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "51 domains (two rules)",
|
||||||
|
domainCount: 51,
|
||||||
|
expectedRuleCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "100 domains (two rules)",
|
||||||
|
domainCount: 100,
|
||||||
|
expectedRuleCount: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "125 domains (three rules: 50+50+25)",
|
||||||
|
domainCount: 125,
|
||||||
|
expectedRuleCount: 3,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Clean up before each subtest
|
||||||
|
cleanupRegistryKeys(t)
|
||||||
|
|
||||||
|
// Generate domains
|
||||||
|
domains := make([]DomainConfig, tc.domainCount)
|
||||||
|
for i := 0; i < tc.domainCount; i++ {
|
||||||
|
domains[i] = DomainConfig{
|
||||||
|
Domain: fmt.Sprintf("domain%d.com", i+1),
|
||||||
|
MatchOnly: true,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
config := HostDNSConfig{
|
||||||
|
ServerIP: testIP,
|
||||||
|
Domains: domains,
|
||||||
|
}
|
||||||
|
|
||||||
|
err := cfg.applyDNSConfig(config, nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify that exactly expectedRuleCount rules were created
|
||||||
|
assert.Equal(t, tc.expectedRuleCount, cfg.nrptEntryCount,
|
||||||
|
"Should create %d NRPT rules for %d domains", tc.expectedRuleCount, tc.domainCount)
|
||||||
|
|
||||||
|
// Verify all expected rules exist
|
||||||
|
for i := 0; i < tc.expectedRuleCount; i++ {
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, i))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.True(t, exists, "NRPT rule %d should exist", i)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no extra rules were created
|
||||||
|
exists, err := registryKeyExists(fmt.Sprintf("%s-%d", dnsPolicyConfigMatchPath, tc.expectedRuleCount))
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.False(t, exists, "No NRPT rule should exist at index %d", tc.expectedRuleCount)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -77,7 +77,7 @@ func (d *Resolver) ID() types.HandlerID {
|
|||||||
return "local-resolver"
|
return "local-resolver"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *Resolver) ProbeAvailability() {}
|
func (d *Resolver) ProbeAvailability(context.Context) {}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (d *Resolver) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
|
|||||||
@@ -376,9 +376,9 @@ func (m *Resolver) extractDomainsFromServerDomains(serverDomains dnsconfig.Serve
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if serverDomains.Flow != "" {
|
// Flow receiver domain is intentionally excluded from caching.
|
||||||
domains = append(domains, serverDomains.Flow)
|
// Cloud providers may rotate the IP behind this domain; a stale cached record
|
||||||
}
|
// causes TLS certificate verification failures on reconnect.
|
||||||
|
|
||||||
for _, stun := range serverDomains.Stuns {
|
for _, stun := range serverDomains.Stuns {
|
||||||
if stun != "" {
|
if stun != "" {
|
||||||
|
|||||||
@@ -391,7 +391,8 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
|||||||
}
|
}
|
||||||
assert.Len(t, resolver.GetCachedDomains(), 3)
|
assert.Len(t, resolver.GetCachedDomains(), 3)
|
||||||
|
|
||||||
// Update with partial ServerDomains (only flow domain - new type, should preserve all existing)
|
// Update with partial ServerDomains (only flow domain - flow is intentionally excluded from
|
||||||
|
// caching to prevent TLS failures from stale records, so all existing domains are preserved)
|
||||||
partialDomains := dnsconfig.ServerDomains{
|
partialDomains := dnsconfig.ServerDomains{
|
||||||
Flow: "github.com",
|
Flow: "github.com",
|
||||||
}
|
}
|
||||||
@@ -400,10 +401,10 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
|||||||
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
t.Skipf("Skipping test due to DNS resolution failure: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
assert.Len(t, removedDomains, 0, "Should not remove any domains when adding new type")
|
assert.Len(t, removedDomains, 0, "Should not remove any domains when only flow domain is provided")
|
||||||
|
|
||||||
finalDomains := resolver.GetCachedDomains()
|
finalDomains := resolver.GetCachedDomains()
|
||||||
assert.Len(t, finalDomains, 4, "Should have all original domains plus new flow domain")
|
assert.Len(t, finalDomains, 3, "Flow domain is not cached; all original domains should be preserved")
|
||||||
|
|
||||||
domainStrings := make([]string, len(finalDomains))
|
domainStrings := make([]string, len(finalDomains))
|
||||||
for i, d := range finalDomains {
|
for i, d := range finalDomains {
|
||||||
@@ -412,5 +413,5 @@ func TestResolver_PartialUpdateAddsNewTypePreservesExisting(t *testing.T) {
|
|||||||
assert.Contains(t, domainStrings, "example.org")
|
assert.Contains(t, domainStrings, "example.org")
|
||||||
assert.Contains(t, domainStrings, "google.com")
|
assert.Contains(t, domainStrings, "google.com")
|
||||||
assert.Contains(t, domainStrings, "cloudflare.com")
|
assert.Contains(t, domainStrings, "cloudflare.com")
|
||||||
assert.Contains(t, domainStrings, "github.com")
|
assert.NotContains(t, domainStrings, "github.com")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -84,3 +84,23 @@ func (m *MockServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
|||||||
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
func (m *MockServer) PopulateManagementDomain(mgmtURL *url.URL) error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRouteChecker mock implementation of SetRouteChecker from Server interface
|
||||||
|
func (m *MockServer) SetRouteChecker(func(netip.Addr) bool) {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginBatch mock implementation of BeginBatch from Server interface
|
||||||
|
func (m *MockServer) BeginBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndBatch mock implementation of EndBatch from Server interface
|
||||||
|
func (m *MockServer) EndBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|
||||||
|
// CancelBatch mock implementation of CancelBatch from Server interface
|
||||||
|
func (m *MockServer) CancelBatch() {
|
||||||
|
// Mock implementation - no-op
|
||||||
|
}
|
||||||
|
|||||||
@@ -6,7 +6,9 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"net/url"
|
"net/url"
|
||||||
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -27,6 +29,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const envSkipDNSProbe = "NB_SKIP_DNS_PROBE"
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
type ReadyListener interface {
|
type ReadyListener interface {
|
||||||
OnReady()
|
OnReady()
|
||||||
@@ -41,6 +45,9 @@ type IosDnsManager interface {
|
|||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||||
DeregisterHandler(domains domain.List, priority int)
|
DeregisterHandler(domains domain.List, priority int)
|
||||||
|
BeginBatch()
|
||||||
|
EndBatch()
|
||||||
|
CancelBatch()
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() netip.Addr
|
DnsIP() netip.Addr
|
||||||
@@ -50,6 +57,7 @@ type Server interface {
|
|||||||
ProbeAvailability()
|
ProbeAvailability()
|
||||||
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
UpdateServerConfig(domains dnsconfig.ServerDomains) error
|
||||||
PopulateManagementDomain(mgmtURL *url.URL) error
|
PopulateManagementDomain(mgmtURL *url.URL) error
|
||||||
|
SetRouteChecker(func(netip.Addr) bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type nsGroupsByDomain struct {
|
type nsGroupsByDomain struct {
|
||||||
@@ -83,6 +91,7 @@ type DefaultServer struct {
|
|||||||
currentConfigHash uint64
|
currentConfigHash uint64
|
||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
extraDomains map[domain.Domain]int
|
extraDomains map[domain.Domain]int
|
||||||
|
batchMode bool
|
||||||
|
|
||||||
mgmtCacheResolver *mgmt.Resolver
|
mgmtCacheResolver *mgmt.Resolver
|
||||||
|
|
||||||
@@ -96,12 +105,17 @@ type DefaultServer struct {
|
|||||||
|
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
routeMatch func(netip.Addr) bool
|
||||||
|
|
||||||
|
probeMu sync.Mutex
|
||||||
|
probeCancel context.CancelFunc
|
||||||
|
probeWg sync.WaitGroup
|
||||||
}
|
}
|
||||||
|
|
||||||
type handlerWithStop interface {
|
type handlerWithStop interface {
|
||||||
dns.Handler
|
dns.Handler
|
||||||
Stop()
|
Stop()
|
||||||
ProbeAvailability()
|
ProbeAvailability(context.Context)
|
||||||
ID() types.HandlerID
|
ID() types.HandlerID
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -217,6 +231,14 @@ func newDefaultServer(
|
|||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetRouteChecker sets the function used by upstream resolvers to determine
|
||||||
|
// whether an IP is routed through the tunnel.
|
||||||
|
func (s *DefaultServer) SetRouteChecker(f func(netip.Addr) bool) {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
s.routeMatch = f
|
||||||
|
}
|
||||||
|
|
||||||
// RegisterHandler registers a handler for the given domains with the given priority.
|
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||||
// Any previously registered handler for the same domain and priority will be replaced.
|
// Any previously registered handler for the same domain and priority will be replaced.
|
||||||
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
@@ -230,7 +252,9 @@ func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler
|
|||||||
// convert to zone with simple ref counter
|
// convert to zone with simple ref counter
|
||||||
s.extraDomains[toZone(domain)]++
|
s.extraDomains[toZone(domain)]++
|
||||||
}
|
}
|
||||||
s.applyHostConfig()
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
@@ -259,9 +283,41 @@ func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
|||||||
delete(s.extraDomains, zone)
|
delete(s.extraDomains, zone)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
if !s.batchMode {
|
||||||
|
s.applyHostConfig()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BeginBatch starts batch mode for DNS handler registration/deregistration.
|
||||||
|
// In batch mode, applyHostConfig() is not called after each handler operation,
|
||||||
|
// allowing multiple handlers to be registered/deregistered efficiently.
|
||||||
|
// Must be followed by EndBatch() to apply the accumulated changes.
|
||||||
|
func (s *DefaultServer) BeginBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Debugf("DNS batch mode enabled")
|
||||||
|
s.batchMode = true
|
||||||
|
}
|
||||||
|
|
||||||
|
// EndBatch ends batch mode and applies all accumulated DNS configuration changes.
|
||||||
|
func (s *DefaultServer) EndBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Debugf("DNS batch mode disabled, applying accumulated changes")
|
||||||
|
s.batchMode = false
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// CancelBatch cancels batch mode without applying accumulated changes.
|
||||||
|
// This is useful when operations fail partway through and you want to
|
||||||
|
// discard partial state rather than applying it.
|
||||||
|
func (s *DefaultServer) CancelBatch() {
|
||||||
|
s.mux.Lock()
|
||||||
|
defer s.mux.Unlock()
|
||||||
|
log.Debugf("DNS batch mode cancelled, discarding accumulated changes")
|
||||||
|
s.batchMode = false
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
log.Debugf("deregistering handler with priority %d for %v", priority, domains)
|
||||||
|
|
||||||
@@ -320,7 +376,13 @@ func (s *DefaultServer) DnsIP() netip.Addr {
|
|||||||
|
|
||||||
// Stop stops the server
|
// Stop stops the server
|
||||||
func (s *DefaultServer) Stop() {
|
func (s *DefaultServer) Stop() {
|
||||||
|
s.probeMu.Lock()
|
||||||
|
if s.probeCancel != nil {
|
||||||
|
s.probeCancel()
|
||||||
|
}
|
||||||
s.ctxCancel()
|
s.ctxCancel()
|
||||||
|
s.probeMu.Unlock()
|
||||||
|
s.probeWg.Wait()
|
||||||
s.shutdownWg.Wait()
|
s.shutdownWg.Wait()
|
||||||
|
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
@@ -437,17 +499,66 @@ func (s *DefaultServer) SearchDomains() []string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// ProbeAvailability tests each upstream group's servers for availability
|
// ProbeAvailability tests each upstream group's servers for availability
|
||||||
// and deactivates the group if no server responds
|
// and deactivates the group if no server responds.
|
||||||
|
// If a previous probe is still running, it will be cancelled before starting a new one.
|
||||||
func (s *DefaultServer) ProbeAvailability() {
|
func (s *DefaultServer) ProbeAvailability() {
|
||||||
var wg sync.WaitGroup
|
if val := os.Getenv(envSkipDNSProbe); val != "" {
|
||||||
for _, mux := range s.dnsMuxMap {
|
skipProbe, err := strconv.ParseBool(val)
|
||||||
wg.Add(1)
|
if err != nil {
|
||||||
go func(mux handlerWithStop) {
|
log.Warnf("failed to parse %s: %v", envSkipDNSProbe, err)
|
||||||
defer wg.Done()
|
}
|
||||||
mux.ProbeAvailability()
|
if skipProbe {
|
||||||
}(mux.handler)
|
log.Infof("skipping DNS probe due to %s", envSkipDNSProbe)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
s.probeMu.Lock()
|
||||||
|
|
||||||
|
// don't start probes on a stopped server
|
||||||
|
if s.ctx.Err() != nil {
|
||||||
|
s.probeMu.Unlock()
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// cancel any running probe
|
||||||
|
if s.probeCancel != nil {
|
||||||
|
s.probeCancel()
|
||||||
|
s.probeCancel = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// wait for the previous probe goroutines to finish while holding
|
||||||
|
// the mutex so no other caller can start a new probe concurrently
|
||||||
|
s.probeWg.Wait()
|
||||||
|
|
||||||
|
// start a new probe
|
||||||
|
probeCtx, probeCancel := context.WithCancel(s.ctx)
|
||||||
|
s.probeCancel = probeCancel
|
||||||
|
|
||||||
|
s.probeWg.Add(1)
|
||||||
|
defer s.probeWg.Done()
|
||||||
|
|
||||||
|
// Snapshot handlers under s.mux to avoid racing with updateMux/dnsMuxMap writers.
|
||||||
|
s.mux.Lock()
|
||||||
|
handlers := make([]handlerWithStop, 0, len(s.dnsMuxMap))
|
||||||
|
for _, mux := range s.dnsMuxMap {
|
||||||
|
handlers = append(handlers, mux.handler)
|
||||||
|
}
|
||||||
|
s.mux.Unlock()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
for _, handler := range handlers {
|
||||||
|
wg.Add(1)
|
||||||
|
go func(h handlerWithStop) {
|
||||||
|
defer wg.Done()
|
||||||
|
h.ProbeAvailability(probeCtx)
|
||||||
|
}(handler)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.probeMu.Unlock()
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
probeCancel()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
func (s *DefaultServer) UpdateServerConfig(domains dnsconfig.ServerDomains) error {
|
||||||
@@ -508,6 +619,7 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
s.currentConfig.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always apply host config for management updates, regardless of batch mode
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
s.shutdownWg.Add(1)
|
s.shutdownWg.Add(1)
|
||||||
@@ -641,6 +753,7 @@ func (s *DefaultServer) registerFallback(config HostDNSConfig) {
|
|||||||
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
|
log.Errorf("failed to create upstream resolver for original nameservers: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handler.routeMatch = s.routeMatch
|
||||||
|
|
||||||
for _, ns := range originalNameservers {
|
for _, ns := range originalNameservers {
|
||||||
if ns == config.ServerIP {
|
if ns == config.ServerIP {
|
||||||
@@ -750,6 +863,7 @@ func (s *DefaultServer) createHandlersForDomainGroup(domainGroup nsGroupsByDomai
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
return nil, fmt.Errorf("create upstream resolver: %v", err)
|
||||||
}
|
}
|
||||||
|
handler.routeMatch = s.routeMatch
|
||||||
|
|
||||||
for _, ns := range nsGroup.NameServers {
|
for _, ns := range nsGroup.NameServers {
|
||||||
if ns.NSType != nbdns.UDPNameServerType {
|
if ns.NSType != nbdns.UDPNameServerType {
|
||||||
@@ -872,6 +986,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always apply host config when nameserver goes down, regardless of batch mode
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -907,6 +1022,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Always apply host config when nameserver reactivates, regardless of batch mode
|
||||||
s.applyHostConfig()
|
s.applyHostConfig()
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
@@ -932,6 +1048,7 @@ func (s *DefaultServer) addHostRootZone() {
|
|||||||
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
log.Errorf("unable to create a new upstream resolver, error: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
handler.routeMatch = s.routeMatch
|
||||||
|
|
||||||
handler.upstreamServers = maps.Keys(hostDNSServers)
|
handler.upstreamServers = maps.Keys(hostDNSServers)
|
||||||
handler.deactivate = func(error) {}
|
handler.deactivate = func(error) {}
|
||||||
|
|||||||
@@ -18,7 +18,12 @@ func TestGetServerDns(t *testing.T) {
|
|||||||
t.Errorf("invalid dns server instance: %s", err)
|
t.Errorf("invalid dns server instance: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if srvB != srv {
|
mockSrvB, ok := srvB.(*MockServer)
|
||||||
|
if !ok {
|
||||||
|
t.Errorf("returned server is not a MockServer")
|
||||||
|
}
|
||||||
|
|
||||||
|
if mockSrvB != srv {
|
||||||
t.Errorf("mismatch dns instances")
|
t.Errorf("mismatch dns instances")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1065,7 +1065,7 @@ type mockHandler struct {
|
|||||||
|
|
||||||
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
func (m *mockHandler) ServeDNS(dns.ResponseWriter, *dns.Msg) {}
|
||||||
func (m *mockHandler) Stop() {}
|
func (m *mockHandler) Stop() {}
|
||||||
func (m *mockHandler) ProbeAvailability() {}
|
func (m *mockHandler) ProbeAvailability(context.Context) {}
|
||||||
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
func (m *mockHandler) ID() types.HandlerID { return types.HandlerID(m.Id) }
|
||||||
|
|
||||||
type mockService struct{}
|
type mockService struct{}
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strconv"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -69,7 +70,7 @@ func (s *serviceViaListener) Listen() error {
|
|||||||
return fmt.Errorf("eval listen address: %w", err)
|
return fmt.Errorf("eval listen address: %w", err)
|
||||||
}
|
}
|
||||||
s.listenIP = s.listenIP.Unmap()
|
s.listenIP = s.listenIP.Unmap()
|
||||||
s.server.Addr = fmt.Sprintf("%s:%d", s.listenIP, s.listenPort)
|
s.server.Addr = net.JoinHostPort(s.listenIP.String(), strconv.Itoa(int(s.listenPort)))
|
||||||
log.Debugf("starting dns on %s", s.server.Addr)
|
log.Debugf("starting dns on %s", s.server.Addr)
|
||||||
go func() {
|
go func() {
|
||||||
s.setListenerStatus(true)
|
s.setListenerStatus(true)
|
||||||
@@ -186,7 +187,7 @@ func (s *serviceViaListener) testFreePort(port int) (netip.Addr, bool) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
func (s *serviceViaListener) tryToBind(ip netip.Addr, port int) bool {
|
||||||
addrString := fmt.Sprintf("%s:%d", ip, port)
|
addrString := net.JoinHostPort(ip.String(), strconv.Itoa(port))
|
||||||
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
udpAddr := net.UDPAddrFromAddrPort(netip.MustParseAddrPort(addrString))
|
||||||
probeListener, err := net.ListenUDP("udp", udpAddr)
|
probeListener, err := net.ListenUDP("udp", udpAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -65,10 +65,12 @@ type upstreamResolverBase struct {
|
|||||||
mutex sync.Mutex
|
mutex sync.Mutex
|
||||||
reactivatePeriod time.Duration
|
reactivatePeriod time.Duration
|
||||||
upstreamTimeout time.Duration
|
upstreamTimeout time.Duration
|
||||||
|
wg sync.WaitGroup
|
||||||
|
|
||||||
deactivate func(error)
|
deactivate func(error)
|
||||||
reactivate func()
|
reactivate func()
|
||||||
statusRecorder *peer.Status
|
statusRecorder *peer.Status
|
||||||
|
routeMatch func(netip.Addr) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
type upstreamFailure struct {
|
type upstreamFailure struct {
|
||||||
@@ -115,6 +117,11 @@ func (u *upstreamResolverBase) MatchSubdomains() bool {
|
|||||||
func (u *upstreamResolverBase) Stop() {
|
func (u *upstreamResolverBase) Stop() {
|
||||||
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
log.Debugf("stopping serving DNS for upstreams %s", u.upstreamServers)
|
||||||
u.cancel()
|
u.cancel()
|
||||||
|
|
||||||
|
u.mutex.Lock()
|
||||||
|
u.wg.Wait()
|
||||||
|
u.mutex.Unlock()
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// ServeDNS handles a DNS request
|
// ServeDNS handles a DNS request
|
||||||
@@ -260,16 +267,10 @@ func formatFailures(failures []upstreamFailure) string {
|
|||||||
|
|
||||||
// ProbeAvailability tests all upstream servers simultaneously and
|
// ProbeAvailability tests all upstream servers simultaneously and
|
||||||
// disables the resolver if none work
|
// disables the resolver if none work
|
||||||
func (u *upstreamResolverBase) ProbeAvailability() {
|
func (u *upstreamResolverBase) ProbeAvailability(ctx context.Context) {
|
||||||
u.mutex.Lock()
|
u.mutex.Lock()
|
||||||
defer u.mutex.Unlock()
|
defer u.mutex.Unlock()
|
||||||
|
|
||||||
select {
|
|
||||||
case <-u.ctx.Done():
|
|
||||||
return
|
|
||||||
default:
|
|
||||||
}
|
|
||||||
|
|
||||||
// avoid probe if upstreams could resolve at least one query
|
// avoid probe if upstreams could resolve at least one query
|
||||||
if u.successCount.Load() > 0 {
|
if u.successCount.Load() > 0 {
|
||||||
return
|
return
|
||||||
@@ -279,31 +280,39 @@ func (u *upstreamResolverBase) ProbeAvailability() {
|
|||||||
var mu sync.Mutex
|
var mu sync.Mutex
|
||||||
var wg sync.WaitGroup
|
var wg sync.WaitGroup
|
||||||
|
|
||||||
var errors *multierror.Error
|
var errs *multierror.Error
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
upstream := upstream
|
|
||||||
|
|
||||||
wg.Add(1)
|
wg.Add(1)
|
||||||
go func() {
|
go func(upstream netip.AddrPort) {
|
||||||
defer wg.Done()
|
defer wg.Done()
|
||||||
err := u.testNameserver(upstream, 500*time.Millisecond)
|
err := u.testNameserver(u.ctx, ctx, upstream, 500*time.Millisecond)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
errors = multierror.Append(errors, err)
|
mu.Lock()
|
||||||
|
errs = multierror.Append(errs, err)
|
||||||
|
mu.Unlock()
|
||||||
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
log.Warnf("probing upstream nameserver %s: %s", upstream, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
mu.Lock()
|
mu.Lock()
|
||||||
defer mu.Unlock()
|
|
||||||
success = true
|
success = true
|
||||||
}()
|
mu.Unlock()
|
||||||
|
}(upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
wg.Wait()
|
wg.Wait()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-u.ctx.Done():
|
||||||
|
return
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
|
||||||
// didn't find a working upstream server, let's disable and try later
|
// didn't find a working upstream server, let's disable and try later
|
||||||
if !success {
|
if !success {
|
||||||
u.disable(errors.ErrorOrNil())
|
u.disable(errs.ErrorOrNil())
|
||||||
|
|
||||||
if u.statusRecorder == nil {
|
if u.statusRecorder == nil {
|
||||||
return
|
return
|
||||||
@@ -339,7 +348,7 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, upstream := range u.upstreamServers {
|
for _, upstream := range u.upstreamServers {
|
||||||
if err := u.testNameserver(upstream, probeTimeout); err != nil {
|
if err := u.testNameserver(u.ctx, nil, upstream, probeTimeout); err != nil {
|
||||||
log.Tracef("upstream check for %s: %s", upstream, err)
|
log.Tracef("upstream check for %s: %s", upstream, err)
|
||||||
} else {
|
} else {
|
||||||
// at least one upstream server is available, stop probing
|
// at least one upstream server is available, stop probing
|
||||||
@@ -351,16 +360,22 @@ func (u *upstreamResolverBase) waitUntilResponse() {
|
|||||||
return fmt.Errorf("upstream check call error")
|
return fmt.Errorf("upstream check call error")
|
||||||
}
|
}
|
||||||
|
|
||||||
err := backoff.Retry(operation, exponentialBackOff)
|
err := backoff.Retry(operation, backoff.WithContext(exponentialBackOff, u.ctx))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err)
|
if errors.Is(err, context.Canceled) {
|
||||||
|
log.Debugf("upstream retry loop exited for upstreams %s", u.upstreamServersString())
|
||||||
|
} else {
|
||||||
|
log.Warnf("upstream retry loop exited for upstreams %s: %v", u.upstreamServersString(), err)
|
||||||
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
log.Infof("upstreams %s are responsive again. Adding them back to system", u.upstreamServersString())
|
||||||
u.successCount.Add(1)
|
u.successCount.Add(1)
|
||||||
u.reactivate()
|
u.reactivate()
|
||||||
|
u.mutex.Lock()
|
||||||
u.disabled = false
|
u.disabled = false
|
||||||
|
u.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// isTimeout returns true if the given error is a network timeout error.
|
// isTimeout returns true if the given error is a network timeout error.
|
||||||
@@ -383,7 +398,11 @@ func (u *upstreamResolverBase) disable(err error) {
|
|||||||
u.successCount.Store(0)
|
u.successCount.Store(0)
|
||||||
u.deactivate(err)
|
u.deactivate(err)
|
||||||
u.disabled = true
|
u.disabled = true
|
||||||
go u.waitUntilResponse()
|
u.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer u.wg.Done()
|
||||||
|
u.waitUntilResponse()
|
||||||
|
}()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) upstreamServersString() string {
|
func (u *upstreamResolverBase) upstreamServersString() string {
|
||||||
@@ -394,13 +413,18 @@ func (u *upstreamResolverBase) upstreamServersString() string {
|
|||||||
return strings.Join(servers, ", ")
|
return strings.Join(servers, ", ")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolverBase) testNameserver(server netip.AddrPort, timeout time.Duration) error {
|
func (u *upstreamResolverBase) testNameserver(baseCtx context.Context, externalCtx context.Context, server netip.AddrPort, timeout time.Duration) error {
|
||||||
ctx, cancel := context.WithTimeout(u.ctx, timeout)
|
mergedCtx, cancel := context.WithTimeout(baseCtx, timeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
|
if externalCtx != nil {
|
||||||
|
stop2 := context.AfterFunc(externalCtx, cancel)
|
||||||
|
defer stop2()
|
||||||
|
}
|
||||||
|
|
||||||
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
r := new(dns.Msg).SetQuestion(testRecord, dns.TypeSOA)
|
||||||
|
|
||||||
_, _, err := u.upstreamClient.exchange(ctx, server.String(), r)
|
_, _, err := u.upstreamClient.exchange(mergedCtx, server.String(), r)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -65,11 +65,13 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
} else {
|
} else {
|
||||||
upstreamIP = upstreamIP.Unmap()
|
upstreamIP = upstreamIP.Unmap()
|
||||||
}
|
}
|
||||||
if u.lNet.Contains(upstreamIP) || upstreamIP.IsPrivate() {
|
needsPrivate := u.lNet.Contains(upstreamIP) ||
|
||||||
log.Debugf("using private client to query upstream: %s", upstream)
|
(u.routeMatch != nil && u.routeMatch(upstreamIP))
|
||||||
|
if needsPrivate {
|
||||||
|
log.Debugf("using private client to query %s via upstream %s", r.Question[0].Name, upstream)
|
||||||
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
client, err = GetClientPrivate(u.lIP, u.interfaceName, timeout)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, 0, fmt.Errorf("error while creating private client: %s", err)
|
return nil, 0, fmt.Errorf("create private client: %s", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -188,7 +188,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
reactivated = true
|
reactivated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
resolver.ProbeAvailability()
|
resolver.ProbeAvailability(context.TODO())
|
||||||
|
|
||||||
if !failed {
|
if !failed {
|
||||||
t.Errorf("expected that resolving was deactivated")
|
t.Errorf("expected that resolving was deactivated")
|
||||||
|
|||||||
@@ -190,50 +190,75 @@ func (f *DNSForwarder) Close(ctx context.Context) error {
|
|||||||
return nberrors.FormatErrorOrNil(result)
|
return nberrors.FormatErrorOrNil(result)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg) *dns.Msg {
|
func (f *DNSForwarder) handleDNSQuery(logger *log.Entry, w dns.ResponseWriter, query *dns.Msg, startTime time.Time) {
|
||||||
if len(query.Question) == 0 {
|
if len(query.Question) == 0 {
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
question := query.Question[0]
|
question := query.Question[0]
|
||||||
logger.Tracef("received DNS request for DNS forwarder: domain=%s type=%s class=%s",
|
qname := strings.ToLower(question.Name)
|
||||||
question.Name, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
|
||||||
|
|
||||||
domain := strings.ToLower(question.Name)
|
logger.Tracef("question: domain=%s type=%s class=%s",
|
||||||
|
qname, dns.TypeToString[question.Qtype], dns.ClassToString[question.Qclass])
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
network := resutil.NetworkForQtype(question.Qtype)
|
network := resutil.NetworkForQtype(question.Qtype)
|
||||||
if network == "" {
|
if network == "" {
|
||||||
resp.Rcode = dns.RcodeNotImplemented
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(domain, "."))
|
mostSpecificResId, matchingEntries := f.getMatchingEntries(strings.TrimSuffix(qname, "."))
|
||||||
// query doesn't match any configured domain
|
|
||||||
if mostSpecificResId == "" {
|
if mostSpecificResId == "" {
|
||||||
resp.Rcode = dns.RcodeRefused
|
resp.Rcode = dns.RcodeRefused
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
return
|
||||||
}
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
result := resutil.LookupIP(ctx, f.resolver, network, domain, question.Qtype)
|
result := resutil.LookupIP(ctx, f.resolver, network, qname, question.Qtype)
|
||||||
if result.Err != nil {
|
if result.Err != nil {
|
||||||
f.handleDNSError(ctx, logger, w, question, resp, domain, result)
|
f.handleDNSError(ctx, logger, w, question, resp, qname, result, startTime)
|
||||||
return nil
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
f.updateInternalState(result.IPs, mostSpecificResId, matchingEntries)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, result.IPs, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(qname, result.IPs, f.ttl)...)
|
||||||
f.cache.set(domain, question.Qtype, result.IPs)
|
f.cache.set(qname, question.Qtype, result.IPs)
|
||||||
|
|
||||||
return resp
|
f.writeResponse(logger, w, resp, qname, startTime)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *DNSForwarder) writeResponse(logger *log.Entry, w dns.ResponseWriter, resp *dns.Msg, qname string, startTime time.Time) {
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
logger.Errorf("failed to write DNS response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
||||||
|
qname, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpResponseWriter wraps a dns.ResponseWriter to handle UDP-specific truncation.
|
||||||
|
type udpResponseWriter struct {
|
||||||
|
dns.ResponseWriter
|
||||||
|
query *dns.Msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (u *udpResponseWriter) WriteMsg(resp *dns.Msg) error {
|
||||||
|
opt := u.query.IsEdns0()
|
||||||
|
maxSize := dns.MinMsgSize
|
||||||
|
if opt != nil {
|
||||||
|
maxSize = int(opt.UDPSize())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.Len() > maxSize {
|
||||||
|
resp.Truncate(maxSize)
|
||||||
|
}
|
||||||
|
|
||||||
|
return u.ResponseWriter.WriteMsg(resp)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -243,30 +268,7 @@ func (f *DNSForwarder) handleDNSQueryUDP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, &udpResponseWriter{ResponseWriter: w, query: query}, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
opt := query.IsEdns0()
|
|
||||||
maxSize := dns.MinMsgSize
|
|
||||||
if opt != nil {
|
|
||||||
// client advertised a larger EDNS0 buffer
|
|
||||||
maxSize = int(opt.UDPSize())
|
|
||||||
}
|
|
||||||
|
|
||||||
// if our response is too big, truncate and set the TC bit
|
|
||||||
if resp.Len() > maxSize {
|
|
||||||
resp.Truncate(maxSize)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
||||||
@@ -276,18 +278,7 @@ func (f *DNSForwarder) handleDNSQueryTCP(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
"dns_id": fmt.Sprintf("%04x", query.Id),
|
"dns_id": fmt.Sprintf("%04x", query.Id),
|
||||||
})
|
})
|
||||||
|
|
||||||
resp := f.handleDNSQuery(logger, w, query)
|
f.handleDNSQuery(logger, w, query, startTime)
|
||||||
if resp == nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
|
||||||
logger.Errorf("failed to write DNS response: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
logger.Tracef("response: domain=%s rcode=%s answers=%s took=%s",
|
|
||||||
query.Question[0].Name, dns.RcodeToString[resp.Rcode], resutil.FormatAnswers(resp.Answer), time.Since(startTime))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
func (f *DNSForwarder) updateInternalState(ips []netip.Addr, mostSpecificResId route.ResID, matchingEntries []*ForwarderEntry) {
|
||||||
@@ -334,6 +325,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
resp *dns.Msg,
|
resp *dns.Msg,
|
||||||
domain string,
|
domain string,
|
||||||
result resutil.LookupResult,
|
result resutil.LookupResult,
|
||||||
|
startTime time.Time,
|
||||||
) {
|
) {
|
||||||
qType := question.Qtype
|
qType := question.Qtype
|
||||||
qTypeName := dns.TypeToString[qType]
|
qTypeName := dns.TypeToString[qType]
|
||||||
@@ -343,9 +335,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// NotFound: cache negative result and respond
|
// NotFound: cache negative result and respond
|
||||||
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
if result.Rcode == dns.RcodeNameError || result.Rcode == dns.RcodeSuccess {
|
||||||
f.cache.set(domain, question.Qtype, nil)
|
f.cache.set(domain, question.Qtype, nil)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -355,9 +345,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
logger.Debugf("serving cached DNS response after upstream failure: domain=%s type=%s", domain, qTypeName)
|
||||||
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
resp.Answer = append(resp.Answer, resutil.IPsToRRs(domain, ips, f.ttl)...)
|
||||||
resp.Rcode = dns.RcodeSuccess
|
resp.Rcode = dns.RcodeSuccess
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write cached DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -365,9 +353,7 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
verifyResult := resutil.LookupIP(ctx, f.resolver, resutil.NetworkForQtype(qType), domain, qType)
|
||||||
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
if verifyResult.Rcode == dns.RcodeNameError || verifyResult.Rcode == dns.RcodeSuccess {
|
||||||
resp.Rcode = verifyResult.Rcode
|
resp.Rcode = verifyResult.Rcode
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -375,15 +361,12 @@ func (f *DNSForwarder) handleDNSError(
|
|||||||
// No cache or verification failed. Log with or without the server field for more context.
|
// No cache or verification failed. Log with or without the server field for more context.
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
if errors.As(result.Err, &dnsErr) && dnsErr.Server != "" {
|
||||||
logger.Warnf("failed to resolve: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
logger.Warnf("upstream failure: type=%s domain=%s server=%s: %v", qTypeName, domain, dnsErr.Server, result.Err)
|
||||||
} else {
|
} else {
|
||||||
logger.Warnf(errResolveFailed, domain, result.Err)
|
logger.Warnf(errResolveFailed, domain, result.Err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Write final failure response.
|
f.writeResponse(logger, w, resp, domain, startTime)
|
||||||
if writeErr := w.WriteMsg(resp); writeErr != nil {
|
|
||||||
logger.Errorf("failed to write failure DNS response: %v", writeErr)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// getMatchingEntries retrieves the resource IDs for a given domain.
|
// getMatchingEntries retrieves the resource IDs for a given domain.
|
||||||
|
|||||||
@@ -318,8 +318,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
query.SetQuestion(dns.Fqdn(tt.queryDomain), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode, "Expected successful response")
|
||||||
@@ -329,10 +330,9 @@ func TestDNSForwarder_UnauthorizedDomainAccess(t *testing.T) {
|
|||||||
mockFirewall.AssertExpectations(t)
|
mockFirewall.AssertExpectations(t)
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
} else {
|
} else {
|
||||||
if resp != nil {
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
assert.True(t, len(resp.Answer) == 0 || resp.Rcode != dns.RcodeSuccess,
|
||||||
"Unauthorized domain should not return successful answers")
|
"Unauthorized domain should not return successful answers")
|
||||||
}
|
|
||||||
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
mockFirewall.AssertNotCalled(t, "UpdateSet")
|
||||||
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
mockResolver.AssertNotCalled(t, "LookupNetIP")
|
||||||
}
|
}
|
||||||
@@ -466,14 +466,16 @@ func TestDNSForwarder_FirewallSetUpdates(t *testing.T) {
|
|||||||
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
dnsQuery.SetQuestion(dns.Fqdn(tt.query), dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, dnsQuery, time.Now())
|
||||||
|
|
||||||
// Verify response
|
// Verify response
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
if tt.shouldResolve {
|
if tt.shouldResolve {
|
||||||
require.NotNil(t, resp, "Expected response for authorized domain")
|
require.NotNil(t, resp, "Expected response for authorized domain")
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.NotEmpty(t, resp.Answer)
|
require.NotEmpty(t, resp.Answer)
|
||||||
} else if resp != nil {
|
} else {
|
||||||
|
require.NotNil(t, resp, "Expected response")
|
||||||
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
assert.True(t, resp.Rcode == dns.RcodeRefused || len(resp.Answer) == 0,
|
||||||
"Unauthorized domain should be refused or have no answers")
|
"Unauthorized domain should be refused or have no answers")
|
||||||
}
|
}
|
||||||
@@ -528,9 +530,10 @@ func TestDNSForwarder_MultipleIPsInSingleUpdate(t *testing.T) {
|
|||||||
query.SetQuestion("example.com.", dns.TypeA)
|
query.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Verify response contains all IPs
|
// Verify response contains all IPs
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
require.Len(t, resp.Answer, 3, "Should have 3 answer records")
|
||||||
@@ -605,7 +608,7 @@ func TestDNSForwarder_ResponseCodes(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
// Check the response written to the writer
|
// Check the response written to the writer
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
require.NotNil(t, writtenResp, "Expected response to be written")
|
||||||
@@ -675,7 +678,8 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q1.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -683,13 +687,13 @@ func TestDNSForwarder_ServeFromCacheOnUpstreamFailure(t *testing.T) {
|
|||||||
// Second query: serve from cache after upstream failure
|
// Second query: serve from cache after upstream failure
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
q2.SetQuestion(dns.Fqdn("example.com"), dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "expected response to be written")
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2, "expected response to be written")
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -715,7 +719,8 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
q1 := &dns.Msg{}
|
q1 := &dns.Msg{}
|
||||||
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
q1.SetQuestion(mixedQuery+".", dns.TypeA)
|
||||||
w1 := &test.MockResponseWriter{}
|
w1 := &test.MockResponseWriter{}
|
||||||
resp1 := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w1, q1, time.Now())
|
||||||
|
resp1 := w1.GetLastResponse()
|
||||||
require.NotNil(t, resp1)
|
require.NotNil(t, resp1)
|
||||||
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
require.Equal(t, dns.RcodeSuccess, resp1.Rcode)
|
||||||
require.Len(t, resp1.Answer, 1)
|
require.Len(t, resp1.Answer, 1)
|
||||||
@@ -727,13 +732,13 @@ func TestDNSForwarder_CacheNormalizationCasingAndDot(t *testing.T) {
|
|||||||
|
|
||||||
q2 := &dns.Msg{}
|
q2 := &dns.Msg{}
|
||||||
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
q2.SetQuestion("EXAMPLE.COM", dns.TypeA)
|
||||||
var writtenResp *dns.Msg
|
w2 := &test.MockResponseWriter{}
|
||||||
w2 := &test.MockResponseWriter{WriteMsgFunc: func(m *dns.Msg) error { writtenResp = m; return nil }}
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2, time.Now())
|
||||||
_ = forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), w2, q2)
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp)
|
resp2 := w2.GetLastResponse()
|
||||||
require.Equal(t, dns.RcodeSuccess, writtenResp.Rcode)
|
require.NotNil(t, resp2)
|
||||||
require.Len(t, writtenResp.Answer, 1)
|
require.Equal(t, dns.RcodeSuccess, resp2.Rcode)
|
||||||
|
require.Len(t, resp2.Answer, 1)
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
@@ -784,8 +789,9 @@ func TestDNSForwarder_MultipleOverlappingPatterns(t *testing.T) {
|
|||||||
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
query.SetQuestion("smtp.mail.example.com.", dns.TypeA)
|
||||||
|
|
||||||
mockWriter := &test.MockResponseWriter{}
|
mockWriter := &test.MockResponseWriter{}
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
|
|
||||||
|
resp := mockWriter.GetLastResponse()
|
||||||
require.NotNil(t, resp)
|
require.NotNil(t, resp)
|
||||||
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
assert.Equal(t, dns.RcodeSuccess, resp.Rcode)
|
||||||
|
|
||||||
@@ -897,26 +903,15 @@ func TestDNSForwarder_NodataVsNxdomain(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
query.SetQuestion(dns.Fqdn("example.com"), tt.queryType)
|
||||||
|
|
||||||
var writtenResp *dns.Msg
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writtenResp = m
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
resp := mockWriter.GetLastResponse()
|
||||||
|
require.NotNil(t, resp, "Expected response to be written")
|
||||||
// If a response was returned, it means it should be written (happens in wrapper functions)
|
assert.Equal(t, tt.expectedCode, resp.Rcode, tt.description)
|
||||||
if resp != nil && writtenResp == nil {
|
|
||||||
writtenResp = resp
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NotNil(t, writtenResp, "Expected response to be written")
|
|
||||||
assert.Equal(t, tt.expectedCode, writtenResp.Rcode, tt.description)
|
|
||||||
|
|
||||||
if tt.expectNoAnswer {
|
if tt.expectNoAnswer {
|
||||||
assert.Empty(t, writtenResp.Answer, "Response should have no answer records")
|
assert.Empty(t, resp.Answer, "Response should have no answer records")
|
||||||
}
|
}
|
||||||
|
|
||||||
mockResolver.AssertExpectations(t)
|
mockResolver.AssertExpectations(t)
|
||||||
@@ -931,15 +926,8 @@ func TestDNSForwarder_EmptyQuery(t *testing.T) {
|
|||||||
query := &dns.Msg{}
|
query := &dns.Msg{}
|
||||||
// Don't set any question
|
// Don't set any question
|
||||||
|
|
||||||
writeCalled := false
|
mockWriter := &test.MockResponseWriter{}
|
||||||
mockWriter := &test.MockResponseWriter{
|
forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query, time.Now())
|
||||||
WriteMsgFunc: func(m *dns.Msg) error {
|
|
||||||
writeCalled = true
|
|
||||||
return nil
|
|
||||||
},
|
|
||||||
}
|
|
||||||
resp := forwarder.handleDNSQuery(log.NewEntry(log.StandardLogger()), mockWriter, query)
|
|
||||||
|
|
||||||
assert.Nil(t, resp, "Should return nil for empty query")
|
assert.Nil(t, mockWriter.GetLastResponse(), "Should not write response for empty query")
|
||||||
assert.False(t, writeCalled, "Should not write response for empty query")
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -29,13 +29,16 @@ import (
|
|||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
nbnetstack "github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/udpmux"
|
"github.com/netbirdio/netbird/client/iface/udpmux"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl"
|
"github.com/netbirdio/netbird/client/internal/acl"
|
||||||
"github.com/netbirdio/netbird/client/internal/debug"
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
dnsconfig "github.com/netbirdio/netbird/client/internal/dns/config"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/expose"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/metrics"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
@@ -49,16 +52,14 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/routemanager"
|
"github.com/netbirdio/netbird/client/internal/routemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/updatemanager"
|
"github.com/netbirdio/netbird/client/internal/updater"
|
||||||
"github.com/netbirdio/netbird/client/jobexec"
|
"github.com/netbirdio/netbird/client/jobexec"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/shared/management/domain"
|
|
||||||
semaphoregroup "github.com/netbirdio/netbird/util/semaphore-group"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
mgm "github.com/netbirdio/netbird/shared/management/client"
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
"github.com/netbirdio/netbird/shared/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
mgmProto "github.com/netbirdio/netbird/shared/management/proto"
|
||||||
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
auth "github.com/netbirdio/netbird/shared/relay/auth/hmac"
|
||||||
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
relayClient "github.com/netbirdio/netbird/shared/relay/client"
|
||||||
@@ -74,13 +75,11 @@ import (
|
|||||||
const (
|
const (
|
||||||
PeerConnectionTimeoutMax = 45000 // ms
|
PeerConnectionTimeoutMax = 45000 // ms
|
||||||
PeerConnectionTimeoutMin = 30000 // ms
|
PeerConnectionTimeoutMin = 30000 // ms
|
||||||
connInitLimit = 200
|
|
||||||
disableAutoUpdate = "disabled"
|
disableAutoUpdate = "disabled"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrResetConnection = fmt.Errorf("reset connection")
|
var ErrResetConnection = fmt.Errorf("reset connection")
|
||||||
|
|
||||||
// EngineConfig is a config for the Engine
|
|
||||||
type EngineConfig struct {
|
type EngineConfig struct {
|
||||||
WgPort int
|
WgPort int
|
||||||
WgIfaceName string
|
WgIfaceName string
|
||||||
@@ -142,6 +141,18 @@ type EngineConfig struct {
|
|||||||
LogPath string
|
LogPath string
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// EngineServices holds the external service dependencies required by the Engine.
|
||||||
|
type EngineServices struct {
|
||||||
|
SignalClient signal.Client
|
||||||
|
MgmClient mgm.Client
|
||||||
|
RelayManager *relayClient.Manager
|
||||||
|
StatusRecorder *peer.Status
|
||||||
|
Checks []*mgmProto.Checks
|
||||||
|
StateManager *statemanager.Manager
|
||||||
|
UpdateManager *updater.Manager
|
||||||
|
ClientMetrics *metrics.ClientMetrics
|
||||||
|
}
|
||||||
|
|
||||||
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
// Engine is a mechanism responsible for reacting on Signal and Management stream events and managing connections to the remote peers.
|
||||||
type Engine struct {
|
type Engine struct {
|
||||||
// signal is a Signal Service client
|
// signal is a Signal Service client
|
||||||
@@ -207,11 +218,10 @@ type Engine struct {
|
|||||||
syncRespMux sync.RWMutex
|
syncRespMux sync.RWMutex
|
||||||
persistSyncResponse bool
|
persistSyncResponse bool
|
||||||
latestSyncResponse *mgmProto.SyncResponse
|
latestSyncResponse *mgmProto.SyncResponse
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
|
||||||
flowManager nftypes.FlowManager
|
flowManager nftypes.FlowManager
|
||||||
|
|
||||||
// auto-update
|
// auto-update
|
||||||
updateManager *updatemanager.Manager
|
updateManager *updater.Manager
|
||||||
|
|
||||||
// WireGuard interface monitor
|
// WireGuard interface monitor
|
||||||
wgIfaceMonitor *WGIfaceMonitor
|
wgIfaceMonitor *WGIfaceMonitor
|
||||||
@@ -221,8 +231,13 @@ type Engine struct {
|
|||||||
|
|
||||||
probeStunTurn *relay.StunTurnProbe
|
probeStunTurn *relay.StunTurnProbe
|
||||||
|
|
||||||
|
// clientMetrics collects and pushes metrics
|
||||||
|
clientMetrics *metrics.ClientMetrics
|
||||||
|
|
||||||
jobExecutor *jobexec.Executor
|
jobExecutor *jobexec.Executor
|
||||||
jobExecutorWG sync.WaitGroup
|
jobExecutorWG sync.WaitGroup
|
||||||
|
|
||||||
|
exposeManager *expose.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -239,22 +254,17 @@ type localIpUpdater interface {
|
|||||||
func NewEngine(
|
func NewEngine(
|
||||||
clientCtx context.Context,
|
clientCtx context.Context,
|
||||||
clientCancel context.CancelFunc,
|
clientCancel context.CancelFunc,
|
||||||
signalClient signal.Client,
|
|
||||||
mgmClient mgm.Client,
|
|
||||||
relayManager *relayClient.Manager,
|
|
||||||
config *EngineConfig,
|
config *EngineConfig,
|
||||||
|
services EngineServices,
|
||||||
mobileDep MobileDependency,
|
mobileDep MobileDependency,
|
||||||
statusRecorder *peer.Status,
|
|
||||||
checks []*mgmProto.Checks,
|
|
||||||
stateManager *statemanager.Manager,
|
|
||||||
) *Engine {
|
) *Engine {
|
||||||
engine := &Engine{
|
engine := &Engine{
|
||||||
clientCtx: clientCtx,
|
clientCtx: clientCtx,
|
||||||
clientCancel: clientCancel,
|
clientCancel: clientCancel,
|
||||||
signal: signalClient,
|
signal: services.SignalClient,
|
||||||
signaler: peer.NewSignaler(signalClient, config.WgPrivateKey),
|
signaler: peer.NewSignaler(services.SignalClient, config.WgPrivateKey),
|
||||||
mgmClient: mgmClient,
|
mgmClient: services.MgmClient,
|
||||||
relayManager: relayManager,
|
relayManager: services.RelayManager,
|
||||||
peerStore: peerstore.NewConnStore(),
|
peerStore: peerstore.NewConnStore(),
|
||||||
syncMsgMux: &sync.Mutex{},
|
syncMsgMux: &sync.Mutex{},
|
||||||
config: config,
|
config: config,
|
||||||
@@ -262,12 +272,13 @@ func NewEngine(
|
|||||||
STUNs: []*stun.URI{},
|
STUNs: []*stun.URI{},
|
||||||
TURNs: []*stun.URI{},
|
TURNs: []*stun.URI{},
|
||||||
networkSerial: 0,
|
networkSerial: 0,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: services.StatusRecorder,
|
||||||
stateManager: stateManager,
|
stateManager: services.StateManager,
|
||||||
checks: checks,
|
checks: services.Checks,
|
||||||
connSemaphore: semaphoregroup.NewSemaphoreGroup(connInitLimit),
|
|
||||||
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
probeStunTurn: relay.NewStunTurnProbe(relay.DefaultCacheTTL),
|
||||||
jobExecutor: jobexec.NewExecutor(),
|
jobExecutor: jobexec.NewExecutor(),
|
||||||
|
clientMetrics: services.ClientMetrics,
|
||||||
|
updateManager: services.UpdateManager,
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
log.Infof("I am: %s", config.WgPrivateKey.PublicKey().String())
|
||||||
@@ -310,7 +321,7 @@ func (e *Engine) Stop() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if e.updateManager != nil {
|
if e.updateManager != nil {
|
||||||
e.updateManager.Stop()
|
e.updateManager.SetDownloadOnly()
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Info("cleaning up status recorder states")
|
log.Info("cleaning up status recorder states")
|
||||||
@@ -418,6 +429,7 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
e.cancel()
|
e.cancel()
|
||||||
}
|
}
|
||||||
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
e.ctx, e.cancel = context.WithCancel(e.clientCtx)
|
||||||
|
e.exposeManager = expose.NewManager(e.ctx, e.mgmClient)
|
||||||
|
|
||||||
wgIface, err := e.newWgIface()
|
wgIface, err := e.newWgIface()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -487,6 +499,17 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
|
|
||||||
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
e.routeManager.SetRouteChangeListener(e.mobileDep.NetworkChangeListener)
|
||||||
|
|
||||||
|
e.dnsServer.SetRouteChecker(func(ip netip.Addr) bool {
|
||||||
|
for _, routes := range e.routeManager.GetClientRoutes() {
|
||||||
|
for _, r := range routes {
|
||||||
|
if r.Network.Contains(ip) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
})
|
||||||
|
|
||||||
if err = e.wgInterfaceCreate(); err != nil {
|
if err = e.wgInterfaceCreate(); err != nil {
|
||||||
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
log.Errorf("failed creating tunnel interface %s: [%s]", e.config.WgIfaceName, err.Error())
|
||||||
e.close()
|
e.close()
|
||||||
@@ -543,11 +566,12 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
// monitor WireGuard interface lifecycle and restart engine on changes
|
// monitor WireGuard interface lifecycle and restart engine on changes
|
||||||
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
e.wgIfaceMonitor = NewWGIfaceMonitor()
|
||||||
e.shutdownWg.Add(1)
|
e.shutdownWg.Add(1)
|
||||||
|
wgIfaceName := e.wgInterface.Name()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
|
|
||||||
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, e.wgInterface.Name()); shouldRestart {
|
if shouldRestart, err := e.wgIfaceMonitor.Start(e.ctx, wgIfaceName); shouldRestart {
|
||||||
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
log.Infof("WireGuard interface monitor: %s, restarting engine", err)
|
||||||
e.triggerClientRestart()
|
e.triggerClientRestart()
|
||||||
} else if err != nil {
|
} else if err != nil {
|
||||||
@@ -558,13 +582,6 @@ func (e *Engine) Start(netbirdConfig *mgmProto.NetbirdConfig, mgmtURL *url.URL)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) InitialUpdateHandling(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
|
||||||
e.syncMsgMux.Lock()
|
|
||||||
defer e.syncMsgMux.Unlock()
|
|
||||||
|
|
||||||
e.handleAutoUpdateVersion(autoUpdateSettings, true)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *Engine) createFirewall() error {
|
func (e *Engine) createFirewall() error {
|
||||||
if e.config.DisableFirewall {
|
if e.config.DisableFirewall {
|
||||||
log.Infof("firewall is disabled")
|
log.Infof("firewall is disabled")
|
||||||
@@ -792,42 +809,31 @@ func (e *Engine) PopulateNetbirdConfig(netbirdConfig *mgmProto.NetbirdConfig, mg
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings, initialCheck bool) {
|
func (e *Engine) handleAutoUpdateVersion(autoUpdateSettings *mgmProto.AutoUpdateSettings) {
|
||||||
|
if e.updateManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
if autoUpdateSettings == nil {
|
if autoUpdateSettings == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
disabled := autoUpdateSettings.Version == disableAutoUpdate
|
if autoUpdateSettings.Version == disableAutoUpdate {
|
||||||
|
log.Infof("auto-update is disabled")
|
||||||
// Stop and cleanup if disabled
|
e.updateManager.SetDownloadOnly()
|
||||||
if e.updateManager != nil && disabled {
|
|
||||||
log.Infof("auto-update is disabled, stopping update manager")
|
|
||||||
e.updateManager.Stop()
|
|
||||||
e.updateManager = nil
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
// Skip check unless AlwaysUpdate is enabled or this is the initial check at startup
|
e.updateManager.SetVersion(autoUpdateSettings.Version, autoUpdateSettings.AlwaysUpdate)
|
||||||
if !autoUpdateSettings.AlwaysUpdate && !initialCheck {
|
|
||||||
log.Debugf("skipping auto-update check, AlwaysUpdate is false and this is not the initial check")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
// Start manager if needed
|
|
||||||
if e.updateManager == nil {
|
|
||||||
log.Infof("starting auto-update manager")
|
|
||||||
updateManager, err := updatemanager.NewManager(e.statusRecorder, e.stateManager)
|
|
||||||
if err != nil {
|
|
||||||
return
|
|
||||||
}
|
|
||||||
e.updateManager = updateManager
|
|
||||||
e.updateManager.Start(e.ctx)
|
|
||||||
}
|
|
||||||
log.Infof("handling auto-update version: %s", autoUpdateSettings.Version)
|
|
||||||
e.updateManager.SetVersion(autoUpdateSettings.Version)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
||||||
|
started := time.Now()
|
||||||
|
defer func() {
|
||||||
|
duration := time.Since(started)
|
||||||
|
log.Infof("sync finished in %s", duration)
|
||||||
|
e.clientMetrics.RecordSyncDuration(e.ctx, duration)
|
||||||
|
}()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
|
||||||
@@ -837,7 +843,7 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
if update.NetworkMap != nil && update.NetworkMap.PeerConfig != nil {
|
||||||
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate, false)
|
e.handleAutoUpdateVersion(update.NetworkMap.PeerConfig.AutoUpdate)
|
||||||
}
|
}
|
||||||
|
|
||||||
if update.GetNetbirdConfig() != nil {
|
if update.GetNetbirdConfig() != nil {
|
||||||
@@ -1002,10 +1008,11 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
return errors.New("wireguard interface is not initialized")
|
return errors.New("wireguard interface is not initialized")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Cannot update the IP address without restarting the engine because
|
|
||||||
// the firewall, route manager, and other components cache the old address
|
|
||||||
if e.wgInterface.Address().String() != conf.Address {
|
if e.wgInterface.Address().String() != conf.Address {
|
||||||
log.Infof("peer IP address has changed from %s to %s", e.wgInterface.Address().String(), conf.Address)
|
log.Infof("peer IP address changed from %s to %s, restarting client", e.wgInterface.Address().String(), conf.Address)
|
||||||
|
_ = CtxGetState(e.ctx).Wrap(ErrResetConnection)
|
||||||
|
e.clientCancel()
|
||||||
|
return ErrResetConnection
|
||||||
}
|
}
|
||||||
|
|
||||||
if conf.GetSshConfig() != nil {
|
if conf.GetSshConfig() != nil {
|
||||||
@@ -1017,7 +1024,7 @@ func (e *Engine) updateConfig(conf *mgmProto.PeerConfig) error {
|
|||||||
state := e.statusRecorder.GetLocalPeerState()
|
state := e.statusRecorder.GetLocalPeerState()
|
||||||
state.IP = e.wgInterface.Address().String()
|
state.IP = e.wgInterface.Address().String()
|
||||||
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
state.PubKey = e.config.WgPrivateKey.PublicKey().String()
|
||||||
state.KernelInterface = device.WireGuardModuleIsLoaded()
|
state.KernelInterface = !e.wgInterface.IsUserspaceBind()
|
||||||
state.FQDN = conf.GetFqdn()
|
state.FQDN = conf.GetFqdn()
|
||||||
|
|
||||||
e.statusRecorder.UpdateLocalPeerState(state)
|
e.statusRecorder.UpdateLocalPeerState(state)
|
||||||
@@ -1073,6 +1080,7 @@ func (e *Engine) handleBundle(params *mgmProto.BundleParameters) (*mgmProto.JobR
|
|||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
SyncResponse: syncResponse,
|
SyncResponse: syncResponse,
|
||||||
LogPath: e.config.LogPath,
|
LogPath: e.config.LogPath,
|
||||||
|
ClientMetrics: e.clientMetrics,
|
||||||
RefreshStatus: func() {
|
RefreshStatus: func() {
|
||||||
e.RunHealthProbes(true)
|
e.RunHealthProbes(true)
|
||||||
},
|
},
|
||||||
@@ -1310,8 +1318,7 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
// Test received (upstream) servers for availability right away instead of upon usage.
|
// Test received (upstream) servers for availability right away instead of upon usage.
|
||||||
// If no server of a server group responds this will disable the respective handler and retry later.
|
// If no server of a server group responds this will disable the respective handler and retry later.
|
||||||
e.dnsServer.ProbeAvailability()
|
go e.dnsServer.ProbeAvailability()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1528,12 +1535,12 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix, agentV
|
|||||||
}
|
}
|
||||||
|
|
||||||
serviceDependencies := peer.ServiceDependencies{
|
serviceDependencies := peer.ServiceDependencies{
|
||||||
StatusRecorder: e.statusRecorder,
|
StatusRecorder: e.statusRecorder,
|
||||||
Signaler: e.signaler,
|
Signaler: e.signaler,
|
||||||
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
IFaceDiscover: e.mobileDep.IFaceDiscover,
|
||||||
RelayManager: e.relayManager,
|
RelayManager: e.relayManager,
|
||||||
SrWatcher: e.srWatcher,
|
SrWatcher: e.srWatcher,
|
||||||
Semaphore: e.connSemaphore,
|
MetricsRecorder: e.clientMetrics,
|
||||||
}
|
}
|
||||||
peerConn, err := peer.NewConn(config, serviceDependencies)
|
peerConn, err := peer.NewConn(config, serviceDependencies)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1556,8 +1563,10 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
defer e.shutdownWg.Done()
|
defer e.shutdownWg.Done()
|
||||||
// connect to a stream of messages coming from the signal server
|
// connect to a stream of messages coming from the signal server
|
||||||
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
err := e.signal.Receive(e.ctx, func(msg *sProto.Message) error {
|
||||||
|
start := time.Now()
|
||||||
e.syncMsgMux.Lock()
|
e.syncMsgMux.Lock()
|
||||||
defer e.syncMsgMux.Unlock()
|
defer e.syncMsgMux.Unlock()
|
||||||
|
gotLock := time.Since(start)
|
||||||
|
|
||||||
// Check context INSIDE lock to ensure atomicity with shutdown
|
// Check context INSIDE lock to ensure atomicity with shutdown
|
||||||
if e.ctx.Err() != nil {
|
if e.ctx.Err() != nil {
|
||||||
@@ -1581,6 +1590,8 @@ func (e *Engine) receiveSignalEvents() {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("receiveMSG: took %s to get lock for peer %s with session id %s", gotLock, msg.Key, offerAnswer.SessionID)
|
||||||
|
|
||||||
if msg.Body.Type == sProto.Body_OFFER {
|
if msg.Body.Type == sProto.Body_OFFER {
|
||||||
conn.OnRemoteOffer(*offerAnswer)
|
conn.OnRemoteOffer(*offerAnswer)
|
||||||
} else {
|
} else {
|
||||||
@@ -1814,11 +1825,23 @@ func (e *Engine) GetRouteManager() routemanager.Manager {
|
|||||||
return e.routeManager
|
return e.routeManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetFirewallManager returns the firewall manager
|
// GetFirewallManager returns the firewall manager.
|
||||||
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
func (e *Engine) GetFirewallManager() firewallManager.Manager {
|
||||||
return e.firewall
|
return e.firewall
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetExposeManager returns the expose session manager.
|
||||||
|
func (e *Engine) GetExposeManager() *expose.Manager {
|
||||||
|
e.syncMsgMux.Lock()
|
||||||
|
defer e.syncMsgMux.Unlock()
|
||||||
|
return e.exposeManager
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetClientMetrics returns the client metrics
|
||||||
|
func (e *Engine) GetClientMetrics() *metrics.ClientMetrics {
|
||||||
|
return e.clientMetrics
|
||||||
|
}
|
||||||
|
|
||||||
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
func findIPFromInterfaceName(ifaceName string) (net.IP, error) {
|
||||||
iface, err := net.InterfaceByName(ifaceName)
|
iface, err := net.InterfaceByName(ifaceName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1918,7 +1941,7 @@ func (e *Engine) triggerClientRestart() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (e *Engine) startNetworkMonitor() {
|
func (e *Engine) startNetworkMonitor() {
|
||||||
if !e.config.NetworkMonitor {
|
if !e.config.NetworkMonitor || nbnetstack.IsEnabled() {
|
||||||
log.Infof("Network monitor is disabled, not starting")
|
log.Infof("Network monitor is disabled, not starting")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
firewallManager "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
sshauth "github.com/netbirdio/netbird/client/ssh/auth"
|
||||||
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
sshconfig "github.com/netbirdio/netbird/client/ssh/config"
|
||||||
@@ -94,6 +95,10 @@ func (e *Engine) updateSSH(sshConf *mgmProto.SSHConfig) error {
|
|||||||
|
|
||||||
// updateSSHClientConfig updates the SSH client configuration with peer information
|
// updateSSHClientConfig updates the SSH client configuration with peer information
|
||||||
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
func (e *Engine) updateSSHClientConfig(remotePeers []*mgmProto.RemotePeerConfig) error {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
peerInfo := e.extractPeerSSHInfo(remotePeers)
|
||||||
if len(peerInfo) == 0 {
|
if len(peerInfo) == 0 {
|
||||||
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
log.Debug("no SSH-enabled peers found, skipping SSH config update")
|
||||||
@@ -216,6 +221,10 @@ func (e *Engine) GetPeerSSHKey(peerAddress string) ([]byte, bool) {
|
|||||||
|
|
||||||
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
// cleanupSSHConfig removes NetBird SSH client configuration on shutdown
|
||||||
func (e *Engine) cleanupSSHConfig() {
|
func (e *Engine) cleanupSSHConfig() {
|
||||||
|
if netstack.IsEnabled() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
configMgr := sshconfig.New()
|
configMgr := sshconfig.New()
|
||||||
|
|
||||||
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
if err := configMgr.RemoveSSHClientConfig(); err != nil {
|
||||||
|
|||||||
@@ -251,9 +251,6 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(
|
engine := NewEngine(
|
||||||
ctx, cancel,
|
ctx, cancel,
|
||||||
&signal.MockClient{},
|
|
||||||
&mgmt.MockClient{},
|
|
||||||
relayMgr,
|
|
||||||
&EngineConfig{
|
&EngineConfig{
|
||||||
WgIfaceName: "utun101",
|
WgIfaceName: "utun101",
|
||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
@@ -263,10 +260,13 @@ func TestEngine_SSH(t *testing.T) {
|
|||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
SSHKey: sshKey,
|
SSHKey: sshKey,
|
||||||
},
|
},
|
||||||
|
EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
},
|
||||||
MobileDependency{},
|
MobileDependency{},
|
||||||
peer.NewRecorder("https://mgm"),
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -428,13 +428,18 @@ func TestEngine_UpdateNetworkMap(t *testing.T) {
|
|||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||||
WgIfaceName: "utun102",
|
WgIfaceName: "utun102",
|
||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
}, EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{})
|
||||||
|
|
||||||
wgIface := &MockWGIface{
|
wgIface := &MockWGIface{
|
||||||
NameFunc: func() string { return "utun102" },
|
NameFunc: func() string { return "utun102" },
|
||||||
@@ -647,13 +652,18 @@ func TestEngine_Sync(t *testing.T) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{SyncFunc: syncFunc}, relayMgr, &EngineConfig{
|
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||||
WgIfaceName: "utun103",
|
WgIfaceName: "utun103",
|
||||||
WgAddr: "100.64.0.1/24",
|
WgAddr: "100.64.0.1/24",
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
}, EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{SyncFunc: syncFunc},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{})
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
engine.dnsServer = &dns.MockServer{
|
engine.dnsServer = &dns.MockServer{
|
||||||
@@ -812,13 +822,18 @@ func TestEngine_UpdateNetworkMapWithRoutes(t *testing.T) {
|
|||||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||||
WgIfaceName: wgIfaceName,
|
WgIfaceName: wgIfaceName,
|
||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
}, EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{})
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -1014,13 +1029,18 @@ func TestEngine_UpdateNetworkMapWithDNSUpdate(t *testing.T) {
|
|||||||
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
wgAddr := fmt.Sprintf("100.66.%d.1/24", n)
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
engine := NewEngine(ctx, cancel, &signal.MockClient{}, &mgmt.MockClient{}, relayMgr, &EngineConfig{
|
engine := NewEngine(ctx, cancel, &EngineConfig{
|
||||||
WgIfaceName: wgIfaceName,
|
WgIfaceName: wgIfaceName,
|
||||||
WgAddr: wgAddr,
|
WgAddr: wgAddr,
|
||||||
WgPrivateKey: key,
|
WgPrivateKey: key,
|
||||||
WgPort: 33100,
|
WgPort: 33100,
|
||||||
MTU: iface.DefaultMTU,
|
MTU: iface.DefaultMTU,
|
||||||
}, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil)
|
}, EngineServices{
|
||||||
|
SignalClient: &signal.MockClient{},
|
||||||
|
MgmClient: &mgmt.MockClient{},
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{})
|
||||||
engine.ctx = ctx
|
engine.ctx = ctx
|
||||||
|
|
||||||
newNet, err := stdnet.NewNet(context.Background(), nil)
|
newNet, err := stdnet.NewNet(context.Background(), nil)
|
||||||
@@ -1546,7 +1566,12 @@ func createEngine(ctx context.Context, cancel context.CancelFunc, setupKey strin
|
|||||||
}
|
}
|
||||||
|
|
||||||
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
relayMgr := relayClient.NewManager(ctx, nil, key.PublicKey().String(), iface.DefaultMTU)
|
||||||
e, err := NewEngine(ctx, cancel, signalClient, mgmtClient, relayMgr, conf, MobileDependency{}, peer.NewRecorder("https://mgm"), nil, nil), nil
|
e, err := NewEngine(ctx, cancel, conf, EngineServices{
|
||||||
|
SignalClient: signalClient,
|
||||||
|
MgmClient: mgmtClient,
|
||||||
|
RelayManager: relayMgr,
|
||||||
|
StatusRecorder: peer.NewRecorder("https://mgm"),
|
||||||
|
}, MobileDependency{}), nil
|
||||||
e.ctx = ctx
|
e.ctx = ctx
|
||||||
return e, err
|
return e, err
|
||||||
}
|
}
|
||||||
|
|||||||
104
client/internal/expose/manager.go
Normal file
104
client/internal/expose/manager.go
Normal file
@@ -0,0 +1,104 @@
|
|||||||
|
package expose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
renewTimeout = 10 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// Response holds the response from exposing a service.
|
||||||
|
type Response struct {
|
||||||
|
ServiceName string
|
||||||
|
ServiceURL string
|
||||||
|
Domain string
|
||||||
|
PortAutoAssigned bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Request holds the parameters for exposing a local service via the management server.
|
||||||
|
// It is part of the embed API surface and exposed via a type alias.
|
||||||
|
type Request struct {
|
||||||
|
NamePrefix string
|
||||||
|
Domain string
|
||||||
|
Port uint16
|
||||||
|
Protocol ProtocolType
|
||||||
|
Pin string
|
||||||
|
Password string
|
||||||
|
UserGroups []string
|
||||||
|
ListenPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
type ManagementClient interface {
|
||||||
|
CreateExpose(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error)
|
||||||
|
RenewExpose(ctx context.Context, domain string) error
|
||||||
|
StopExpose(ctx context.Context, domain string) error
|
||||||
|
}
|
||||||
|
|
||||||
|
// Manager handles expose session lifecycle via the management client.
|
||||||
|
type Manager struct {
|
||||||
|
mgmClient ManagementClient
|
||||||
|
ctx context.Context
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewManager creates a new expose Manager using the given management client.
|
||||||
|
func NewManager(ctx context.Context, mgmClient ManagementClient) *Manager {
|
||||||
|
return &Manager{mgmClient: mgmClient, ctx: ctx}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Expose creates a new expose session via the management server.
|
||||||
|
func (m *Manager) Expose(ctx context.Context, req Request) (*Response, error) {
|
||||||
|
log.Infof("exposing service on port %d", req.Port)
|
||||||
|
resp, err := m.mgmClient.CreateExpose(ctx, toClientExposeRequest(req))
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("expose session created for %s", resp.Domain)
|
||||||
|
|
||||||
|
return fromClientExposeResponse(resp), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// KeepAlive periodically renews the expose session for the given domain until the context is canceled or an error occurs.
|
||||||
|
// It is part of the embed API surface and exposed via a type alias.
|
||||||
|
func (m *Manager) KeepAlive(ctx context.Context, domain string) error {
|
||||||
|
ticker := time.NewTicker(30 * time.Second)
|
||||||
|
defer ticker.Stop()
|
||||||
|
defer m.stop(domain)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
log.Infof("context canceled, stopping keep alive for %s", domain)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
case <-ticker.C:
|
||||||
|
if err := m.renew(ctx, domain); err != nil {
|
||||||
|
log.Errorf("renewing expose session for %s: %v", domain, err)
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// renew extends the TTL of an active expose session.
|
||||||
|
func (m *Manager) renew(ctx context.Context, domain string) error {
|
||||||
|
renewCtx, cancel := context.WithTimeout(ctx, renewTimeout)
|
||||||
|
defer cancel()
|
||||||
|
return m.mgmClient.RenewExpose(renewCtx, domain)
|
||||||
|
}
|
||||||
|
|
||||||
|
// stop terminates an active expose session.
|
||||||
|
func (m *Manager) stop(domain string) {
|
||||||
|
stopCtx, cancel := context.WithTimeout(m.ctx, renewTimeout)
|
||||||
|
defer cancel()
|
||||||
|
err := m.mgmClient.StopExpose(stopCtx, domain)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed stopping expose session for %s: %v", domain, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
95
client/internal/expose/manager_test.go
Normal file
95
client/internal/expose/manager_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package expose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestManager_Expose_Success(t *testing.T) {
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||||
|
return &mgm.ExposeResponse{
|
||||||
|
ServiceName: "my-service",
|
||||||
|
ServiceURL: "https://my-service.example.com",
|
||||||
|
Domain: "my-service.example.com",
|
||||||
|
}, nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(context.Background(), mock)
|
||||||
|
result, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Equal(t, "my-service", result.ServiceName, "service name should match")
|
||||||
|
assert.Equal(t, "https://my-service.example.com", result.ServiceURL, "service URL should match")
|
||||||
|
assert.Equal(t, "my-service.example.com", result.Domain, "domain should match")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Expose_Error(t *testing.T) {
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
CreateExposeFunc: func(ctx context.Context, req mgm.ExposeRequest) (*mgm.ExposeResponse, error) {
|
||||||
|
return nil, errors.New("permission denied")
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(context.Background(), mock)
|
||||||
|
_, err := m.Expose(context.Background(), Request{Port: 8080})
|
||||||
|
require.Error(t, err)
|
||||||
|
assert.Contains(t, err.Error(), "permission denied", "error should propagate")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Renew_Success(t *testing.T) {
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||||
|
assert.Equal(t, "my-service.example.com", domain, "domain should be passed through")
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(context.Background(), mock)
|
||||||
|
err := m.renew(context.Background(), "my-service.example.com")
|
||||||
|
require.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestManager_Renew_Timeout(t *testing.T) {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
cancel()
|
||||||
|
|
||||||
|
mock := &mgm.MockClient{
|
||||||
|
RenewExposeFunc: func(ctx context.Context, domain string) error {
|
||||||
|
return ctx.Err()
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m := NewManager(ctx, mock)
|
||||||
|
err := m.renew(ctx, "my-service.example.com")
|
||||||
|
require.Error(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewRequest(t *testing.T) {
|
||||||
|
req := &daemonProto.ExposeServiceRequest{
|
||||||
|
Port: 8080,
|
||||||
|
Protocol: daemonProto.ExposeProtocol_EXPOSE_HTTPS,
|
||||||
|
Pin: "123456",
|
||||||
|
Password: "secret",
|
||||||
|
UserGroups: []string{"group1", "group2"},
|
||||||
|
Domain: "custom.example.com",
|
||||||
|
NamePrefix: "my-prefix",
|
||||||
|
}
|
||||||
|
|
||||||
|
exposeReq := NewRequest(req)
|
||||||
|
|
||||||
|
assert.Equal(t, uint16(8080), exposeReq.Port, "port should match")
|
||||||
|
assert.Equal(t, ProtocolType(daemonProto.ExposeProtocol_EXPOSE_HTTPS), exposeReq.Protocol, "protocol should match")
|
||||||
|
assert.Equal(t, "123456", exposeReq.Pin, "pin should match")
|
||||||
|
assert.Equal(t, "secret", exposeReq.Password, "password should match")
|
||||||
|
assert.Equal(t, []string{"group1", "group2"}, exposeReq.UserGroups, "user groups should match")
|
||||||
|
assert.Equal(t, "custom.example.com", exposeReq.Domain, "domain should match")
|
||||||
|
assert.Equal(t, "my-prefix", exposeReq.NamePrefix, "name prefix should match")
|
||||||
|
}
|
||||||
40
client/internal/expose/protocol.go
Normal file
40
client/internal/expose/protocol.go
Normal file
@@ -0,0 +1,40 @@
|
|||||||
|
package expose
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ProtocolType represents the protocol used for exposing a service.
// The numeric values are converted directly from the daemon proto expose
// enum elsewhere in this package, so they must not be reordered.
type ProtocolType int

const (
	// ProtocolHTTP exposes the service as HTTP.
	ProtocolHTTP ProtocolType = iota
	// ProtocolHTTPS exposes the service as HTTPS.
	ProtocolHTTPS
	// ProtocolTCP exposes the service as TCP.
	ProtocolTCP
	// ProtocolUDP exposes the service as UDP.
	ProtocolUDP
	// ProtocolTLS exposes the service as TLS.
	ProtocolTLS
)

// ParseProtocolType parses a protocol string into a ProtocolType.
// Matching is case-insensitive; unknown names yield an error.
func ParseProtocolType(s string) (ProtocolType, error) {
	switch name := strings.ToLower(s); name {
	case "http":
		return ProtocolHTTP, nil
	case "https":
		return ProtocolHTTPS, nil
	case "tcp":
		return ProtocolTCP, nil
	case "udp":
		return ProtocolUDP, nil
	case "tls":
		return ProtocolTLS, nil
	}
	return 0, fmt.Errorf("unsupported protocol %q: must be http, https, tcp, udp, or tls", s)
}
|
||||||
42
client/internal/expose/request.go
Normal file
42
client/internal/expose/request.go
Normal file
@@ -0,0 +1,42 @@
|
|||||||
|
package expose
|
||||||
|
|
||||||
|
import (
|
||||||
|
daemonProto "github.com/netbirdio/netbird/client/proto"
|
||||||
|
mgm "github.com/netbirdio/netbird/shared/management/client"
|
||||||
|
)
|
||||||
|
|
||||||
|
// NewRequest converts a daemon ExposeServiceRequest to a management ExposeServiceRequest.
|
||||||
|
func NewRequest(req *daemonProto.ExposeServiceRequest) *Request {
|
||||||
|
return &Request{
|
||||||
|
Port: uint16(req.Port),
|
||||||
|
Protocol: ProtocolType(req.Protocol),
|
||||||
|
Pin: req.Pin,
|
||||||
|
Password: req.Password,
|
||||||
|
UserGroups: req.UserGroups,
|
||||||
|
Domain: req.Domain,
|
||||||
|
NamePrefix: req.NamePrefix,
|
||||||
|
ListenPort: uint16(req.ListenPort),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// toClientExposeRequest maps the package-internal Request onto the management
// client's ExposeRequest type. Protocol is widened to int; all other fields
// are copied one-to-one.
func toClientExposeRequest(req Request) mgm.ExposeRequest {
	return mgm.ExposeRequest{
		NamePrefix: req.NamePrefix,
		Domain:     req.Domain,
		Port:       req.Port,
		Protocol:   int(req.Protocol),
		Pin:        req.Pin,
		Password:   req.Password,
		UserGroups: req.UserGroups,
		ListenPort: req.ListenPort,
	}
}
|
||||||
|
|
||||||
|
// fromClientExposeResponse maps the management client's ExposeResponse onto
// the package-internal Response type.
// NOTE(review): response is dereferenced without a nil check — callers must
// only pass a non-nil response (e.g. after a successful CreateExpose call).
func fromClientExposeResponse(response *mgm.ExposeResponse) *Response {
	return &Response{
		ServiceName:      response.ServiceName,
		Domain:           response.Domain,
		ServiceURL:       response.ServiceURL,
		PortAutoAssigned: response.PortAutoAssigned,
	}
}
|
||||||
@@ -11,6 +11,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
peerid "github.com/netbirdio/netbird/client/internal/peer/id"
|
||||||
@@ -74,12 +75,13 @@ func (m *Manager) createListener(peerCfg lazyconn.PeerConfig) (listener, error)
|
|||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
// BindListener is only used on Windows and JS platforms:
|
// BindListener is used on Windows, JS, and netstack platforms:
|
||||||
// - JS: Cannot listen to UDP sockets
|
// - JS: Cannot listen to UDP sockets
|
||||||
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
// - Windows: IP_UNICAST_IF socket option forces packets out the interface the default
|
||||||
// gateway points to, preventing them from reaching the loopback interface.
|
// gateway points to, preventing them from reaching the loopback interface.
|
||||||
// BindListener bypasses this by passing data directly through the bind.
|
// - Netstack: Allows multiple instances on the same host without port conflicts.
|
||||||
if runtime.GOOS != "windows" && runtime.GOOS != "js" {
|
// BindListener bypasses these issues by passing data directly through the bind.
|
||||||
|
if runtime.GOOS != "windows" && runtime.GOOS != "js" && !netstack.IsEnabled() {
|
||||||
return NewUDPListener(m.wgIface, peerCfg)
|
return NewUDPListener(m.wgIface, peerCfg)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
17
client/internal/metrics/connection_type.go
Normal file
17
client/internal/metrics/connection_type.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
// ConnectionType identifies how a peer connection was established.
type ConnectionType string

const (
	// ConnectionTypeICE is a direct peer-to-peer connection negotiated via ICE.
	ConnectionTypeICE ConnectionType = "ice"

	// ConnectionTypeRelay is a connection routed through a relay.
	ConnectionTypeRelay ConnectionType = "relay"
)

// String implements fmt.Stringer for ConnectionType.
func (c ConnectionType) String() string {
	return string(c)
}
||||||
51
client/internal/metrics/deployment_type.go
Normal file
51
client/internal/metrics/deployment_type.go
Normal file
@@ -0,0 +1,51 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"strings"
|
||||||
|
)
|
||||||
|
|
||||||
|
// DeploymentType classifies the NetBird backend a client is connected to.
type DeploymentType int

const (
	// DeploymentTypeUnknown represents an unknown or uninitialized deployment type.
	DeploymentTypeUnknown DeploymentType = iota

	// DeploymentTypeCloud represents a cloud-hosted NetBird deployment.
	DeploymentTypeCloud

	// DeploymentTypeSelfHosted represents a self-hosted NetBird deployment.
	DeploymentTypeSelfHosted
)

// String returns the string representation of the deployment type.
func (d DeploymentType) String() string {
	if d == DeploymentTypeCloud {
		return "cloud"
	}
	if d == DeploymentTypeSelfHosted {
		return "selfhosted"
	}
	return "unknown"
}

// DetermineDeploymentType reports whether the management URL points at the
// NetBird cloud service or a self-hosted deployment. An empty URL maps to
// unknown; anything that is not the cloud API host — including unparseable
// URLs — is treated as self-hosted.
func DetermineDeploymentType(managementURL string) DeploymentType {
	if managementURL == "" {
		return DeploymentTypeUnknown
	}

	parsed, err := url.Parse(managementURL)
	if err == nil && strings.EqualFold(parsed.Hostname(), "api.netbird.io") {
		return DeploymentTypeCloud
	}
	return DeploymentTypeSelfHosted
}
|
||||||
93
client/internal/metrics/env.go
Normal file
93
client/internal/metrics/env.go
Normal file
@@ -0,0 +1,93 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/url"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// EnvMetricsPushEnabled controls whether collected metrics are pushed to the backend.
|
||||||
|
// Metrics collection itself is always active (for debug bundles).
|
||||||
|
// Disabled by default. Set NB_METRICS_PUSH_ENABLED=true to enable push.
|
||||||
|
EnvMetricsPushEnabled = "NB_METRICS_PUSH_ENABLED"
|
||||||
|
|
||||||
|
// EnvMetricsForceSending if set to true, skips remote configuration fetch and forces metric sending
|
||||||
|
EnvMetricsForceSending = "NB_METRICS_FORCE_SENDING"
|
||||||
|
|
||||||
|
// EnvMetricsConfigURL is the environment variable to override the metrics push config ServerAddress
|
||||||
|
EnvMetricsConfigURL = "NB_METRICS_CONFIG_URL"
|
||||||
|
|
||||||
|
// EnvMetricsServerURL is the environment variable to override the metrics server address.
|
||||||
|
// When set, this takes precedence over the server_url from remote push config.
|
||||||
|
EnvMetricsServerURL = "NB_METRICS_SERVER_URL"
|
||||||
|
|
||||||
|
// EnvMetricsInterval overrides the push interval from the remote config.
|
||||||
|
// Only affects how often metrics are pushed; remote config availability
|
||||||
|
// and version range checks are still respected.
|
||||||
|
// Format: duration string like "1h", "30m", "4h"
|
||||||
|
EnvMetricsInterval = "NB_METRICS_INTERVAL"
|
||||||
|
|
||||||
|
defaultMetricsConfigURL = "https://ingest.netbird.io/config"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IsMetricsPushEnabled returns true if metrics push is enabled via NB_METRICS_PUSH_ENABLED env var.
|
||||||
|
// Disabled by default. Metrics collection is always active for debug bundles.
|
||||||
|
func IsMetricsPushEnabled() bool {
|
||||||
|
enabled, _ := strconv.ParseBool(os.Getenv(EnvMetricsPushEnabled))
|
||||||
|
return enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetricsInterval returns the metrics push interval from NB_METRICS_INTERVAL env var.
|
||||||
|
// Returns 0 if not set or invalid.
|
||||||
|
func getMetricsInterval() time.Duration {
|
||||||
|
intervalStr := os.Getenv(EnvMetricsInterval)
|
||||||
|
if intervalStr == "" {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
interval, err := time.ParseDuration(intervalStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("invalid metrics interval from env %q: %v", intervalStr, err)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
if interval <= 0 {
|
||||||
|
log.Warnf("invalid metrics interval from env %q: must be positive", intervalStr)
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
return interval
|
||||||
|
}
|
||||||
|
|
||||||
|
func isForceSending() bool {
|
||||||
|
force, _ := strconv.ParseBool(os.Getenv(EnvMetricsForceSending))
|
||||||
|
return force
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetricsConfigURL returns the URL to fetch push configuration from
|
||||||
|
func getMetricsConfigURL() string {
|
||||||
|
if envURL := os.Getenv(EnvMetricsConfigURL); envURL != "" {
|
||||||
|
return envURL
|
||||||
|
}
|
||||||
|
return defaultMetricsConfigURL
|
||||||
|
}
|
||||||
|
|
||||||
|
// getMetricsServerURL returns the metrics server URL from NB_METRICS_SERVER_URL env var.
|
||||||
|
// Returns nil if not set or invalid.
|
||||||
|
func getMetricsServerURL() *url.URL {
|
||||||
|
envURL := os.Getenv(EnvMetricsServerURL)
|
||||||
|
if envURL == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
parsed, err := url.ParseRequestURI(envURL)
|
||||||
|
if err != nil || parsed.Host == "" {
|
||||||
|
log.Warnf("invalid metrics server URL %q: must be an absolute HTTP(S) URL", envURL)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if parsed.Scheme != "http" && parsed.Scheme != "https" {
|
||||||
|
log.Warnf("invalid metrics server URL %q: unsupported scheme %q", envURL, parsed.Scheme)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return parsed
|
||||||
|
}
|
||||||
219
client/internal/metrics/influxdb.go
Normal file
219
client/internal/metrics/influxdb.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"maps"
|
||||||
|
"slices"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
	maxSampleAge  = 5 * 24 * time.Hour // drop samples older than 5 days
	maxBufferSize = 5 * 1024 * 1024    // drop oldest samples when estimated size exceeds 5 MB
	// estimatedSampleSize is a rough per-sample memory estimate (measurement + tags + fields + timestamp)
	estimatedSampleSize = 256
)

// influxSample is a single InfluxDB line protocol entry.
type influxSample struct {
	measurement string             // line protocol measurement name
	tags        string             // pre-rendered "key=value,..." tag section
	fields      map[string]float64 // numeric field values keyed by field name
	timestamp   time.Time          // time the sample was recorded
}

// influxDBMetrics collects metric events as timestamped samples.
// Each event is recorded with its exact timestamp, pushed once, then cleared.
type influxDBMetrics struct {
	mu      sync.Mutex     // guards samples
	samples []influxSample // pending samples, oldest first
}

// newInfluxDBMetrics returns an empty in-memory sample collector.
func newInfluxDBMetrics() metricsImplementation {
	return &influxDBMetrics{}
}
|
||||||
|
// RecordConnectionStages records the durations of the stages of one peer
// connection attempt (signaling received -> connection ready -> WireGuard
// handshake) as a single buffered sample. Stages with a missing (zero)
// timestamp are recorded as 0 seconds.
func (m *influxDBMetrics) RecordConnectionStages(
	_ context.Context,
	agentInfo AgentInfo,
	connectionPairID string,
	connectionType ConnectionType,
	isReconnection bool,
	timestamps ConnectionStageTimestamps,
) {
	var signalingReceivedToConnection, connectionToWgHandshake, totalDuration float64

	// Each stage duration is only computed when both bounding timestamps are
	// set; otherwise it stays 0.
	if !timestamps.SignalingReceived.IsZero() && !timestamps.ConnectionReady.IsZero() {
		signalingReceivedToConnection = timestamps.ConnectionReady.Sub(timestamps.SignalingReceived).Seconds()
	}

	if !timestamps.ConnectionReady.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
		connectionToWgHandshake = timestamps.WgHandshakeSuccess.Sub(timestamps.ConnectionReady).Seconds()
	}

	if !timestamps.SignalingReceived.IsZero() && !timestamps.WgHandshakeSuccess.IsZero() {
		totalDuration = timestamps.WgHandshakeSuccess.Sub(timestamps.SignalingReceived).Seconds()
	}

	attemptType := "initial"
	if isReconnection {
		attemptType = "reconnection"
	}

	connTypeStr := connectionType.String()
	// Tags are pre-rendered in line protocol form; values are assumed to
	// contain no line-protocol special characters (spaces, commas) — TODO confirm.
	tags := fmt.Sprintf("deployment_type=%s,connection_type=%s,attempt_type=%s,version=%s,os=%s,arch=%s,peer_id=%s,connection_pair_id=%s",
		agentInfo.DeploymentType.String(),
		connTypeStr,
		attemptType,
		agentInfo.Version,
		agentInfo.OS,
		agentInfo.Arch,
		agentInfo.peerID,
		connectionPairID,
	)

	now := time.Now()

	m.mu.Lock()
	defer m.mu.Unlock()

	m.samples = append(m.samples, influxSample{
		measurement: "netbird_peer_connection",
		tags:        tags,
		fields: map[string]float64{
			"signaling_to_connection_seconds":    signalingReceivedToConnection,
			"connection_to_wg_handshake_seconds": connectionToWgHandshake,
			"total_seconds":                      totalDuration,
		},
		timestamp: now,
	})
	// Enforce age/size limits immediately so the buffer stays bounded even
	// when pushes are disabled or failing.
	m.trimLocked()

	log.Tracef("peer connection metrics [%s, %s, %s]: signalingReceived→connection: %.3fs, connection→wg_handshake: %.3fs, total: %.3fs",
		agentInfo.DeploymentType.String(), connTypeStr, attemptType, signalingReceivedToConnection, connectionToWgHandshake, totalDuration)
}
|
||||||
|
|
||||||
|
func (m *influxDBMetrics) RecordSyncDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration) {
|
||||||
|
tags := fmt.Sprintf("deployment_type=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||||
|
agentInfo.DeploymentType.String(),
|
||||||
|
agentInfo.Version,
|
||||||
|
agentInfo.OS,
|
||||||
|
agentInfo.Arch,
|
||||||
|
agentInfo.peerID,
|
||||||
|
)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.samples = append(m.samples, influxSample{
|
||||||
|
measurement: "netbird_sync",
|
||||||
|
tags: tags,
|
||||||
|
fields: map[string]float64{
|
||||||
|
"duration_seconds": duration.Seconds(),
|
||||||
|
},
|
||||||
|
timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
m.trimLocked()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *influxDBMetrics) RecordLoginDuration(_ context.Context, agentInfo AgentInfo, duration time.Duration, success bool) {
|
||||||
|
result := "success"
|
||||||
|
if !success {
|
||||||
|
result = "failure"
|
||||||
|
}
|
||||||
|
|
||||||
|
tags := fmt.Sprintf("deployment_type=%s,result=%s,version=%s,os=%s,arch=%s,peer_id=%s",
|
||||||
|
agentInfo.DeploymentType.String(),
|
||||||
|
result,
|
||||||
|
agentInfo.Version,
|
||||||
|
agentInfo.OS,
|
||||||
|
agentInfo.Arch,
|
||||||
|
agentInfo.peerID,
|
||||||
|
)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
|
||||||
|
m.samples = append(m.samples, influxSample{
|
||||||
|
measurement: "netbird_login",
|
||||||
|
tags: tags,
|
||||||
|
fields: map[string]float64{
|
||||||
|
"duration_seconds": duration.Seconds(),
|
||||||
|
},
|
||||||
|
timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
m.trimLocked()
|
||||||
|
|
||||||
|
log.Tracef("login metrics [%s, %s]: duration=%.3fs", agentInfo.DeploymentType.String(), result, duration.Seconds())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Export writes pending samples in InfluxDB line protocol format.
|
||||||
|
// Format: measurement,tag=val,tag=val field=val,field=val timestamp_ns
|
||||||
|
func (m *influxDBMetrics) Export(w io.Writer) error {
|
||||||
|
m.mu.Lock()
|
||||||
|
samples := make([]influxSample, len(m.samples))
|
||||||
|
copy(samples, m.samples)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
for _, s := range samples {
|
||||||
|
if _, err := fmt.Fprintf(w, "%s,%s ", s.measurement, s.tags); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
sortedKeys := slices.Sorted(maps.Keys(s.fields))
|
||||||
|
first := true
|
||||||
|
for _, k := range sortedKeys {
|
||||||
|
if !first {
|
||||||
|
if _, err := fmt.Fprint(w, ","); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if _, err := fmt.Fprintf(w, "%s=%g", k, s.fields[k]); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
first = false
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := fmt.Fprintf(w, " %d\n", s.timestamp.UnixNano()); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Reset clears pending samples after a successful push
|
||||||
|
func (m *influxDBMetrics) Reset() {
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
|
m.samples = m.samples[:0]
|
||||||
|
}
|
||||||
|
|
||||||
|
// trimLocked removes samples that exceed age or size limits.
|
||||||
|
// Must be called with m.mu held.
|
||||||
|
func (m *influxDBMetrics) trimLocked() {
|
||||||
|
now := time.Now()
|
||||||
|
|
||||||
|
// drop samples older than maxSampleAge
|
||||||
|
cutoff := 0
|
||||||
|
for cutoff < len(m.samples) && now.Sub(m.samples[cutoff].timestamp) > maxSampleAge {
|
||||||
|
cutoff++
|
||||||
|
}
|
||||||
|
if cutoff > 0 {
|
||||||
|
copy(m.samples, m.samples[cutoff:])
|
||||||
|
m.samples = m.samples[:len(m.samples)-cutoff]
|
||||||
|
log.Debugf("influxdb metrics: dropped %d samples older than %s", cutoff, maxSampleAge)
|
||||||
|
}
|
||||||
|
|
||||||
|
// drop oldest samples if estimated size exceeds maxBufferSize
|
||||||
|
maxSamples := maxBufferSize / estimatedSampleSize
|
||||||
|
if len(m.samples) > maxSamples {
|
||||||
|
drop := len(m.samples) - maxSamples
|
||||||
|
copy(m.samples, m.samples[drop:])
|
||||||
|
m.samples = m.samples[:maxSamples]
|
||||||
|
log.Debugf("influxdb metrics: dropped %d oldest samples to stay under %d MB size limit", drop, maxBufferSize/(1024*1024))
|
||||||
|
}
|
||||||
|
}
|
||||||
229
client/internal/metrics/influxdb_test.go
Normal file
229
client/internal/metrics/influxdb_test.go
Normal file
@@ -0,0 +1,229 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"context"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_RecordAndExport(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
agentInfo := AgentInfo{
|
||||||
|
DeploymentType: DeploymentTypeCloud,
|
||||||
|
Version: "1.0.0",
|
||||||
|
OS: "linux",
|
||||||
|
Arch: "amd64",
|
||||||
|
peerID: "abc123",
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := ConnectionStageTimestamps{
|
||||||
|
SignalingReceived: time.Now().Add(-3 * time.Second),
|
||||||
|
ConnectionReady: time.Now().Add(-2 * time.Second),
|
||||||
|
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, "netbird_peer_connection,")
|
||||||
|
assert.Contains(t, output, "connection_to_wg_handshake_seconds=")
|
||||||
|
assert.Contains(t, output, "signaling_to_connection_seconds=")
|
||||||
|
assert.Contains(t, output, "total_seconds=")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_ExportDeterministicFieldOrder(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
agentInfo := AgentInfo{
|
||||||
|
DeploymentType: DeploymentTypeCloud,
|
||||||
|
Version: "1.0.0",
|
||||||
|
OS: "linux",
|
||||||
|
Arch: "amd64",
|
||||||
|
peerID: "abc123",
|
||||||
|
}
|
||||||
|
|
||||||
|
ts := ConnectionStageTimestamps{
|
||||||
|
SignalingReceived: time.Now().Add(-3 * time.Second),
|
||||||
|
ConnectionReady: time.Now().Add(-2 * time.Second),
|
||||||
|
WgHandshakeSuccess: time.Now().Add(-1 * time.Second),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Record multiple times and verify consistent field order
|
||||||
|
for i := 0; i < 10; i++ {
|
||||||
|
m.RecordConnectionStages(context.Background(), agentInfo, "pair123", ConnectionTypeICE, false, ts)
|
||||||
|
}
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
lines := strings.Split(strings.TrimSpace(buf.String()), "\n")
|
||||||
|
require.Len(t, lines, 10)
|
||||||
|
|
||||||
|
// Extract field portion from each line and verify they're all identical
|
||||||
|
var fieldSections []string
|
||||||
|
for _, line := range lines {
|
||||||
|
parts := strings.SplitN(line, " ", 3)
|
||||||
|
require.Len(t, parts, 3, "each line should have measurement, fields, timestamp")
|
||||||
|
fieldSections = append(fieldSections, parts[1])
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 1; i < len(fieldSections); i++ {
|
||||||
|
assert.Equal(t, fieldSections[0], fieldSections[i], "field order should be deterministic across samples")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Fields should be alphabetically sorted
|
||||||
|
assert.True(t, strings.HasPrefix(fieldSections[0], "connection_to_wg_handshake_seconds="),
|
||||||
|
"fields should be sorted: connection_to_wg < signaling_to < total")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_RecordSyncDuration(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
agentInfo := AgentInfo{
|
||||||
|
DeploymentType: DeploymentTypeSelfHosted,
|
||||||
|
Version: "2.0.0",
|
||||||
|
OS: "darwin",
|
||||||
|
Arch: "arm64",
|
||||||
|
peerID: "def456",
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RecordSyncDuration(context.Background(), agentInfo, 1500*time.Millisecond)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, "netbird_sync,")
|
||||||
|
assert.Contains(t, output, "duration_seconds=1.5")
|
||||||
|
assert.Contains(t, output, "deployment_type=selfhosted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_Reset(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
agentInfo := AgentInfo{
|
||||||
|
DeploymentType: DeploymentTypeCloud,
|
||||||
|
Version: "1.0.0",
|
||||||
|
OS: "linux",
|
||||||
|
Arch: "amd64",
|
||||||
|
peerID: "abc123",
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RecordSyncDuration(context.Background(), agentInfo, time.Second)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.NotEmpty(t, buf.String())
|
||||||
|
|
||||||
|
m.Reset()
|
||||||
|
|
||||||
|
buf.Reset()
|
||||||
|
err = m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, buf.String(), "should be empty after reset")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_ExportEmpty(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
assert.Empty(t, buf.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_TrimByAge(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.samples = append(m.samples, influxSample{
|
||||||
|
measurement: "old",
|
||||||
|
tags: "t=1",
|
||||||
|
fields: map[string]float64{"v": 1},
|
||||||
|
timestamp: time.Now().Add(-maxSampleAge - time.Hour),
|
||||||
|
})
|
||||||
|
m.trimLocked()
|
||||||
|
remaining := len(m.samples)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
assert.Equal(t, 0, remaining, "old samples should be trimmed")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_RecordLoginDuration(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
agentInfo := AgentInfo{
|
||||||
|
DeploymentType: DeploymentTypeCloud,
|
||||||
|
Version: "1.0.0",
|
||||||
|
OS: "linux",
|
||||||
|
Arch: "amd64",
|
||||||
|
peerID: "abc123",
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RecordLoginDuration(context.Background(), agentInfo, 2500*time.Millisecond, true)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, "netbird_login,")
|
||||||
|
assert.Contains(t, output, "duration_seconds=2.5")
|
||||||
|
assert.Contains(t, output, "result=success")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_RecordLoginDurationFailure(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
agentInfo := AgentInfo{
|
||||||
|
DeploymentType: DeploymentTypeSelfHosted,
|
||||||
|
Version: "1.0.0",
|
||||||
|
OS: "darwin",
|
||||||
|
Arch: "arm64",
|
||||||
|
peerID: "xyz789",
|
||||||
|
}
|
||||||
|
|
||||||
|
m.RecordLoginDuration(context.Background(), agentInfo, 5*time.Second, false)
|
||||||
|
|
||||||
|
var buf bytes.Buffer
|
||||||
|
err := m.Export(&buf)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
output := buf.String()
|
||||||
|
assert.Contains(t, output, "netbird_login,")
|
||||||
|
assert.Contains(t, output, "result=failure")
|
||||||
|
assert.Contains(t, output, "deployment_type=selfhosted")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestInfluxDBMetrics_TrimBySize(t *testing.T) {
|
||||||
|
m := newInfluxDBMetrics().(*influxDBMetrics)
|
||||||
|
|
||||||
|
maxSamples := maxBufferSize / estimatedSampleSize
|
||||||
|
m.mu.Lock()
|
||||||
|
for i := 0; i < maxSamples+100; i++ {
|
||||||
|
m.samples = append(m.samples, influxSample{
|
||||||
|
measurement: "test",
|
||||||
|
tags: "t=1",
|
||||||
|
fields: map[string]float64{"v": float64(i)},
|
||||||
|
timestamp: time.Now(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
m.trimLocked()
|
||||||
|
remaining := len(m.samples)
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
assert.Equal(t, maxSamples, remaining, "should trim to max samples")
|
||||||
|
}
|
||||||
16
client/internal/metrics/infra/.env.example
Normal file
16
client/internal/metrics/infra/.env.example
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
# Copy to .env and adjust values before running docker compose
|
||||||
|
|
||||||
|
# InfluxDB admin (server-side only, never exposed to clients)
|
||||||
|
INFLUXDB_ADMIN_PASSWORD=changeme
|
||||||
|
INFLUXDB_ADMIN_TOKEN=changeme
|
||||||
|
|
||||||
|
# Grafana admin credentials
|
||||||
|
GRAFANA_ADMIN_USER=admin
|
||||||
|
GRAFANA_ADMIN_PASSWORD=changeme
|
||||||
|
|
||||||
|
# Remote config served by ingest at /config
|
||||||
|
# Set CONFIG_METRICS_SERVER_URL to the ingest server's public address to enable metrics push; leave empty to disable
|
||||||
|
CONFIG_METRICS_SERVER_URL=
|
||||||
|
CONFIG_VERSION_SINCE=0.0.0
|
||||||
|
CONFIG_VERSION_UNTIL=99.99.99
|
||||||
|
CONFIG_PERIOD_MINUTES=5
|
||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user