mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-23 02:36:42 +00:00
Compare commits
189 Commits
v0.41.2
...
poc/prepro
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
51eee4c5ac | ||
|
|
8942c40fde | ||
|
|
fbb1b55beb | ||
|
|
77ec32dd6f | ||
|
|
8c09a55057 | ||
|
|
f603ddf35e | ||
|
|
996b8c600c | ||
|
|
c4ed11d447 | ||
|
|
9afbecb7ac | ||
|
|
2c81cf2c1e | ||
|
|
551cb4e467 | ||
|
|
57961afe95 | ||
|
|
22678bce7f | ||
|
|
6c633497bc | ||
|
|
6922826919 | ||
|
|
56a1a75e3f | ||
|
|
d9402168ad | ||
|
|
dbdef04b9e | ||
|
|
29cbfe8467 | ||
|
|
6ce8643368 | ||
|
|
07d1ad35fc | ||
|
|
ef6cd36f1a | ||
|
|
c1c71b6d39 | ||
|
|
0480507a10 | ||
|
|
34ac4e4b5a | ||
|
|
52ff9d9602 | ||
|
|
1b73fae46e | ||
|
|
d897365abc | ||
|
|
f37aa2cc9d | ||
|
|
5343bee7b2 | ||
|
|
870e29db63 | ||
|
|
08e9b05d51 | ||
|
|
3581648071 | ||
|
|
2a51609436 | ||
|
|
83457f8b99 | ||
|
|
b45284f086 | ||
|
|
e9016aecea | ||
|
|
23b5d45b68 | ||
|
|
0e5dc9d412 | ||
|
|
91f7ee6a3c | ||
|
|
7c6b85b4cb | ||
|
|
08c9107c61 | ||
|
|
81d83245e1 | ||
|
|
af2b427751 | ||
|
|
f61ebdb3bc | ||
|
|
de7384e8ea | ||
|
|
75c1be69cf | ||
|
|
424ae28de9 | ||
|
|
d4a800edd5 | ||
|
|
dd9917f1a8 | ||
|
|
8df8c1012f | ||
|
|
bfa5c21d2d | ||
|
|
b1247a14ba | ||
|
|
f595057a0b | ||
|
|
089d442fb2 | ||
|
|
04a3765391 | ||
|
|
d24d8328f9 | ||
|
|
4f63996ae8 | ||
|
|
bdf2994e97 | ||
|
|
6d654acbad | ||
|
|
3e43298471 | ||
|
|
0ad2590974 | ||
|
|
9d11257b1a | ||
|
|
4ee1635baa | ||
|
|
75feb0da8b | ||
|
|
87376afd13 | ||
|
|
b76d9e8e9e | ||
|
|
e71383dcb9 | ||
|
|
e002a2e6e8 | ||
|
|
6127a01196 | ||
|
|
de27d6df36 | ||
|
|
3c535cdd2b | ||
|
|
0f050e5fe1 | ||
|
|
0f7c7f1da2 | ||
|
|
b56f61bf1b | ||
|
|
64f111923e | ||
|
|
122a89c02b | ||
|
|
c6cceba381 | ||
|
|
6c0cdb6ed1 | ||
|
|
84354951d3 | ||
|
|
55957a1960 | ||
|
|
df82a45d99 | ||
|
|
9424b88db2 | ||
|
|
609654eee7 | ||
|
|
b604c66140 | ||
|
|
ea4d13e96d | ||
|
|
87148c503f | ||
|
|
0cd36baf67 | ||
|
|
06980e7fa0 | ||
|
|
1ce4ee0cef | ||
|
|
f367925496 | ||
|
|
616b19c064 | ||
|
|
af27aaf9af | ||
|
|
35287f8241 | ||
|
|
07b220d91b | ||
|
|
41cd4952f1 | ||
|
|
f16f0c7831 | ||
|
|
aa07b3b87b | ||
|
|
2bef214cc0 | ||
|
|
cfb2d82352 | ||
|
|
684501fd35 | ||
|
|
0492c1724a | ||
|
|
6f436e57b5 | ||
|
|
a0d28f9851 | ||
|
|
cdd27a9fe5 | ||
|
|
5523040acd | ||
|
|
670446d42e | ||
|
|
5bed6777d5 | ||
|
|
a0482ebc7b | ||
|
|
2a89d6e47a | ||
|
|
24f932b2ce | ||
|
|
c03435061c | ||
|
|
8e948739f1 | ||
|
|
9b53cad752 | ||
|
|
802a18167c | ||
|
|
e9108ffe6c | ||
|
|
e806d9de38 | ||
|
|
daa8380df9 | ||
|
|
4785f23fc4 | ||
|
|
1d4cfb83e7 | ||
|
|
207fa059d2 | ||
|
|
cbcdad7814 | ||
|
|
701c13807a | ||
|
|
99f8dc7748 | ||
|
|
f1de8e6eb0 | ||
|
|
b2a10780af | ||
|
|
43ae79d848 | ||
|
|
e520b64c6d | ||
|
|
92c91bbdd8 | ||
|
|
adf494e1ac | ||
|
|
2158461121 | ||
|
|
0cd4b601c3 | ||
|
|
ee1cec47b3 | ||
|
|
efb0edfc4c | ||
|
|
20f59ddecb | ||
|
|
2f34e984b0 | ||
|
|
d5b52e86b6 | ||
|
|
cad2fe1f39 | ||
|
|
fcd2c15a37 | ||
|
|
ebda0fc538 | ||
|
|
ac135ab11d | ||
|
|
25faf9283d | ||
|
|
59faaa99f6 | ||
|
|
9762b39f29 | ||
|
|
ffdd115ded | ||
|
|
055df9854c | ||
|
|
12f883badf | ||
|
|
2abb92b0d4 | ||
|
|
01c3719c5d | ||
|
|
7b64953eed | ||
|
|
9bc7d788f0 | ||
|
|
b5419ef11a | ||
|
|
d5081cef90 | ||
|
|
488e619ec7 | ||
|
|
d2b42c8f68 | ||
|
|
2f44fe2e23 | ||
|
|
d8dc107bee | ||
|
|
3fa915e271 | ||
|
|
47c3afe561 | ||
|
|
84bfecdd37 | ||
|
|
3cf87b6846 | ||
|
|
4fe4c2054d | ||
|
|
38ada44a0e | ||
|
|
dbf81a145e | ||
|
|
39483f8ca8 | ||
|
|
c0eaea938e | ||
|
|
ef8b8a2891 | ||
|
|
2817f62c13 | ||
|
|
4a9049566a | ||
|
|
85f92f8321 | ||
|
|
714beb6e3b | ||
|
|
400b9fca32 | ||
|
|
4013298e22 | ||
|
|
312bfd9bd7 | ||
|
|
8db05838ca | ||
|
|
c69df13515 | ||
|
|
986eb8c1e0 | ||
|
|
197761ba4d | ||
|
|
f74ea64c7b | ||
|
|
3b7b9d25bc | ||
|
|
1a6d6b3109 | ||
|
|
f686615876 | ||
|
|
a4311f574d | ||
|
|
0bb8eae903 | ||
|
|
e0b33d325d | ||
|
|
c38e07d89a | ||
|
|
a37368fff4 | ||
|
|
0c93bd3d06 | ||
|
|
a675531b5c |
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -37,17 +37,22 @@ If yes, which one?
|
|||||||
|
|
||||||
**Debug output**
|
**Debug output**
|
||||||
|
|
||||||
To help us resolve the problem, please attach the following debug output
|
To help us resolve the problem, please attach the following anonymized status output
|
||||||
|
|
||||||
netbird status -dA
|
netbird status -dA
|
||||||
|
|
||||||
As well as the file created by
|
Create and upload a debug bundle, and share the returned file key:
|
||||||
|
|
||||||
|
netbird debug for 1m -AS -U
|
||||||
|
|
||||||
|
*Uploaded files are automatically deleted after 30 days.*
|
||||||
|
|
||||||
|
|
||||||
|
Alternatively, create the file only and attach it here manually:
|
||||||
|
|
||||||
netbird debug for 1m -AS
|
netbird debug for 1m -AS
|
||||||
|
|
||||||
|
|
||||||
We advise reviewing the anonymized output for any remaining personal information.
|
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
|
|||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|
||||||
**Have you tried these troubleshooting steps?**
|
**Have you tried these troubleshooting steps?**
|
||||||
|
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||||
- [ ] Checked for newer NetBird versions
|
- [ ] Checked for newer NetBird versions
|
||||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||||
- [ ] Restarted the NetBird client
|
- [ ] Restarted the NetBird client
|
||||||
- [ ] Disabled other VPN software
|
- [ ] Disabled other VPN software
|
||||||
- [ ] Checked firewall settings
|
- [ ] Checked firewall settings
|
||||||
|
|
||||||
|
|||||||
2
.github/pull_request_template.md
vendored
2
.github/pull_request_template.md
vendored
@@ -13,3 +13,5 @@
|
|||||||
- [ ] It is a refactor
|
- [ ] It is a refactor
|
||||||
- [ ] Created tests that fail without the change (if possible)
|
- [ ] Created tests that fail without the change (if possible)
|
||||||
- [ ] Extended the README / documentation, if necessary
|
- [ ] Extended the README / documentation, if necessary
|
||||||
|
|
||||||
|
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||||
|
|||||||
233
.github/workflows/golang-test-linux.yml
vendored
233
.github/workflows/golang-test-linux.yml
vendored
@@ -146,6 +146,65 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||||
|
|
||||||
|
test_client_on_docker:
|
||||||
|
name: "Client (Docker) / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
id: go-env
|
||||||
|
run: |
|
||||||
|
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
id: cache-restore
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ steps.go-env.outputs.cache_dir }}
|
||||||
|
${{ steps.go-env.outputs.modcache_dir }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Run tests in container
|
||||||
|
env:
|
||||||
|
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
|
||||||
|
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
|
||||||
|
run: |
|
||||||
|
CONTAINER_GOCACHE="/root/.cache/go-build"
|
||||||
|
CONTAINER_GOMODCACHE="/go/pkg/mod"
|
||||||
|
|
||||||
|
docker run --rm \
|
||||||
|
--cap-add=NET_ADMIN \
|
||||||
|
--privileged \
|
||||||
|
-v $PWD:/app \
|
||||||
|
-w /app \
|
||||||
|
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
|
||||||
|
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
|
||||||
|
-e CGO_ENABLED=1 \
|
||||||
|
-e CI=true \
|
||||||
|
-e DOCKER_CI=true \
|
||||||
|
-e GOARCH=${GOARCH_TARGET} \
|
||||||
|
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||||
|
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||||
|
golang:1.23-alpine \
|
||||||
|
sh -c ' \
|
||||||
|
apk update; apk add --no-cache \
|
||||||
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||||
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
name: "Relay / Unit"
|
name: "Relay / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -164,6 +223,10 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -179,13 +242,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -217,6 +273,10 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -232,13 +292,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -286,13 +339,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -314,6 +360,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=devcert \
|
go test -tags=devcert \
|
||||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
@@ -353,13 +400,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -380,10 +420,11 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags devcert -run=^$ -bench=. \
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
name: "Management / Benchmark (API)"
|
name: "Management / Benchmark (API)"
|
||||||
@@ -396,6 +437,33 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Create Docker network
|
||||||
|
run: docker network create promnet
|
||||||
|
|
||||||
|
- name: Start Prometheus Pushgateway
|
||||||
|
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||||
|
|
||||||
|
- name: Start Prometheus (for Pushgateway forwarding)
|
||||||
|
run: |
|
||||||
|
echo '
|
||||||
|
global:
|
||||||
|
scrape_interval: 15s
|
||||||
|
scrape_configs:
|
||||||
|
- job_name: "pushgateway"
|
||||||
|
static_configs:
|
||||||
|
- targets: ["pushgateway:9091"]
|
||||||
|
remote_write:
|
||||||
|
- url: ${{ secrets.GRAFANA_URL }}
|
||||||
|
basic_auth:
|
||||||
|
username: ${{ secrets.GRAFANA_USER }}
|
||||||
|
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||||
|
' > prometheus.yml
|
||||||
|
|
||||||
|
docker run -d --name prometheus --network promnet \
|
||||||
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
|
-p 9090:9090 \
|
||||||
|
prom/prometheus
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
@@ -420,13 +488,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -447,11 +508,13 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
|
GIT_BRANCH=${{ github.ref_name }} \
|
||||||
go test -tags=benchmark \
|
go test -tags=benchmark \
|
||||||
-run=^$ \
|
-run=^$ \
|
||||||
-bench=. \
|
-bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
@@ -489,13 +552,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -505,89 +561,8 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=integration \
|
go test -tags=integration \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
test_client_on_docker:
|
|
||||||
name: "Client (Docker) / Unit"
|
|
||||||
needs: [ build-cache ]
|
|
||||||
runs-on: ubuntu-20.04
|
|
||||||
steps:
|
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
|
||||||
run: |
|
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache/restore@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
${{ env.cache }}
|
|
||||||
${{ env.modcache }}
|
|
||||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-gotest-cache-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install modules
|
|
||||||
run: go mod tidy
|
|
||||||
|
|
||||||
- name: check git status
|
|
||||||
run: git --no-pager diff --exit-code
|
|
||||||
|
|
||||||
- name: Generate Shared Sock Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
|
||||||
|
|
||||||
- name: Generate RouteManager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
|
||||||
|
|
||||||
- name: Generate SystemOps Test bin
|
|
||||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
|
||||||
|
|
||||||
- name: Generate Engine Test bin
|
|
||||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
|
||||||
|
|
||||||
- name: Generate Peer Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
|
||||||
|
|
||||||
- run: chmod +x *testing.bin
|
|
||||||
|
|
||||||
- name: Run Shared Sock tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Iface tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
|
||||||
|
|
||||||
- name: Run RouteManager tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run SystemOps tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run nftables Manager tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker with file store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker with sqlite store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|||||||
1
.github/workflows/golangci-lint.yml
vendored
1
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
9
.github/workflows/release.yml
vendored
9
.github/workflows/release.yml
vendored
@@ -9,7 +9,7 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.18"
|
SIGN_PIPE_VER: "v0.0.20"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "NetBird GmbH"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
@@ -65,6 +65,13 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
username: ${{ secrets.DOCKER_USER }}
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
password: ${{ secrets.DOCKER_TOKEN }}
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
- name: Log in to the GitHub container registry
|
||||||
|
if: github.event_name != 'pull_request'
|
||||||
|
uses: docker/login-action@v3
|
||||||
|
with:
|
||||||
|
registry: ghcr.io
|
||||||
|
username: ${{ github.actor }}
|
||||||
|
password: ${{ secrets.CI_DOCKER_PUSH_GITHUB_TOKEN }}
|
||||||
- name: Install OS build dependencies
|
- name: Install OS build dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
run: sudo apt update && sudo apt install -y -q gcc-arm-linux-gnueabihf gcc-aarch64-linux-gnu
|
||||||
|
|
||||||
|
|||||||
@@ -134,6 +134,7 @@ jobs:
|
|||||||
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||||
|
CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY: false
|
||||||
|
|
||||||
run: |
|
run: |
|
||||||
set -x
|
set -x
|
||||||
@@ -172,13 +173,15 @@ jobs:
|
|||||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||||
# check relay values
|
# check relay values
|
||||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||||
grep '33445:33445' docker-compose.yml
|
grep '33445:33445' docker-compose.yml
|
||||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
|
||||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||||
grep DisablePromptLogin management.json | grep 'true'
|
grep DisablePromptLogin management.json | grep 'true'
|
||||||
|
grep LoginFlag management.json | grep 0
|
||||||
|
grep DisableDefaultPolicy management.json | grep "$CI_NETBIRD_MGMT_DISABLE_DEFAULT_POLICY"
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|||||||
198
.goreleaser.yaml
198
.goreleaser.yaml
@@ -96,6 +96,20 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-upload
|
||||||
|
dir: upload-server
|
||||||
|
env: [CGO_ENABLED=0]
|
||||||
|
binary: netbird-upload
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -135,6 +149,7 @@ nfpms:
|
|||||||
dockers:
|
dockers:
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -150,6 +165,7 @@ dockers:
|
|||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-arm64v8
|
- netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -161,10 +177,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-arm
|
- netbirdio/netbird:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -177,11 +194,12 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -193,9 +211,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -207,9 +227,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird
|
- netbird
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -222,10 +244,12 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-amd64
|
- netbirdio/relay:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -237,10 +261,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -252,10 +277,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm
|
- netbirdio/relay:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-relay
|
- netbird-relay
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -268,10 +294,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-amd64
|
- netbirdio/signal:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -283,10 +310,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-arm64v8
|
- netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -298,10 +326,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/signal:{{ .Version }}-arm
|
- netbirdio/signal:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-signal
|
- netbird-signal
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -314,10 +343,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-amd64
|
- netbirdio/management:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -329,10 +359,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-arm64v8
|
- netbirdio/management:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -344,10 +375,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-arm
|
- netbirdio/management:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -360,10 +392,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: amd64
|
goarch: amd64
|
||||||
@@ -375,10 +408,11 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm64
|
goarch: arm64
|
||||||
@@ -390,11 +424,12 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm
|
- netbirdio/management:{{ .Version }}-debug-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||||
ids:
|
ids:
|
||||||
- netbird-mgmt
|
- netbird-mgmt
|
||||||
goarch: arm
|
goarch: arm
|
||||||
@@ -407,7 +442,56 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-upload
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: upload-server/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-upload
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: upload-server/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-upload
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: upload-server/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.source=https://github.com/netbirdio/{{.ProjectName}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
@@ -475,7 +559,95 @@ docker_manifests:
|
|||||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm
|
- netbirdio/management:{{ .Version }}-debug-arm
|
||||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
|
- name_template: netbirdio/upload:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/upload:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:{{ .Version }}-rootless
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/netbird:rootless-latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- ghcr.io/netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/relay:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/relay:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/relay:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/signal:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/signal:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/signal:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/management:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/management:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/management:debug-latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-arm
|
||||||
|
- ghcr.io/netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/upload:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: ghcr.io/netbirdio/upload:latest
|
||||||
|
image_templates:
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- ghcr.io/netbirdio/upload:{{ .Version }}-amd64
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
|
|||||||
13
README.md
13
README.md
@@ -12,9 +12,12 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
<a href="https://docs.netbird.io/slack-url">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
|
<a href="https://forum.netbird.io">
|
||||||
|
<img src="https://img.shields.io/badge/community forum-@netbird-red.svg?logo=discourse"/>
|
||||||
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://gurubase.io/g/netbird">
|
<a href="https://gurubase.io/g/netbird">
|
||||||
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
<img src="https://img.shields.io/badge/Gurubase-Ask%20NetBird%20Guru-006BFF"/>
|
||||||
@@ -29,13 +32,13 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a> or our <a href="https://forum.netbird.io">Community forum</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://github.com/netbirdio/kubernetes-operator">
|
<a href="https://registry.terraform.io/providers/netbirdio/netbird/latest">
|
||||||
New: NetBird Kubernetes Operator
|
New: NetBird terraform provider
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -61,7 +64,7 @@
|
|||||||
|----|----|----|----|----|
|
|----|----|----|----|----|
|
||||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||||
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||||
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||||
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||||
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||||
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||||
|
|||||||
@@ -1,5 +1,9 @@
|
|||||||
FROM alpine:3.21.3
|
FROM alpine:3.21.3
|
||||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
|
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||||
|
|
||||||
|
ARG NETBIRD_BINARY=netbird
|
||||||
|
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||||
|
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||||
COPY netbird /usr/local/bin/netbird
|
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
FROM alpine:3.21.0
|
FROM alpine:3.21.0
|
||||||
|
|
||||||
COPY netbird /usr/local/bin/netbird
|
ARG NETBIRD_BINARY=netbird
|
||||||
|
COPY ${NETBIRD_BINARY} /usr/local/bin/netbird
|
||||||
|
|
||||||
RUN apk add --no-cache ca-certificates \
|
RUN apk add --no-cache ca-certificates \
|
||||||
&& adduser -D -h /var/lib/netbird netbird
|
&& adduser -D -h /var/lib/netbird netbird
|
||||||
|
|||||||
@@ -59,6 +59,8 @@ type Client struct {
|
|||||||
deviceName string
|
deviceName string
|
||||||
uiVersion string
|
uiVersion string
|
||||||
networkChangeListener listener.NetworkChangeListener
|
networkChangeListener listener.NetworkChangeListener
|
||||||
|
|
||||||
|
connectClient *internal.ConnectClient
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewClient instantiate a new Client
|
// NewClient instantiate a new Client
|
||||||
@@ -106,8 +108,8 @@ func (c *Client) Run(urlOpener URLOpener, dns *DNSList, dnsReadyListener DnsRead
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
// RunWithoutLogin we apply this type of run function when the backed has been started without UI (i.e. after reboot).
|
||||||
@@ -132,8 +134,8 @@ func (c *Client) RunWithoutLogin(dns *DNSList, dnsReadyListener DnsReadyListener
|
|||||||
|
|
||||||
// todo do not throw error in case of cancelled context
|
// todo do not throw error in case of cancelled context
|
||||||
ctx = internal.CtxInitState(ctx)
|
ctx = internal.CtxInitState(ctx)
|
||||||
connectClient := internal.NewConnectClient(ctx, cfg, c.recorder)
|
c.connectClient = internal.NewConnectClient(ctx, cfg, c.recorder)
|
||||||
return connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
return c.connectClient.RunOnAndroid(c.tunAdapter, c.iFaceDiscover, c.networkChangeListener, dns.items, dnsReadyListener)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Stop the internal client and free the resources
|
// Stop the internal client and free the resources
|
||||||
@@ -174,6 +176,55 @@ func (c *Client) PeersList() *PeerInfoArray {
|
|||||||
return &PeerInfoArray{items: peerInfos}
|
return &PeerInfoArray{items: peerInfos}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Client) Networks() *NetworkArray {
|
||||||
|
if c.connectClient == nil {
|
||||||
|
log.Error("not connected")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
engine := c.connectClient.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
log.Error("could not get engine")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
routeManager := engine.GetRouteManager()
|
||||||
|
if routeManager == nil {
|
||||||
|
log.Error("could not get route manager")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
networkArray := &NetworkArray{
|
||||||
|
items: make([]Network, 0),
|
||||||
|
}
|
||||||
|
|
||||||
|
for id, routes := range routeManager.GetClientRoutesWithNetID() {
|
||||||
|
if len(routes) == 0 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
r := routes[0]
|
||||||
|
netStr := r.Network.String()
|
||||||
|
if r.IsDynamic() {
|
||||||
|
netStr = r.Domains.SafeString()
|
||||||
|
}
|
||||||
|
|
||||||
|
peer, err := c.recorder.GetPeer(routes[0].Peer)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("could not get peer info for %s: %v", routes[0].Peer, err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
network := Network{
|
||||||
|
Name: string(id),
|
||||||
|
Network: netStr,
|
||||||
|
Peer: peer.FQDN,
|
||||||
|
Status: peer.ConnStatus.String(),
|
||||||
|
}
|
||||||
|
networkArray.Add(network)
|
||||||
|
}
|
||||||
|
return networkArray
|
||||||
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
// OnUpdatedHostDNS update the DNS servers addresses for root zones
|
||||||
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
func (c *Client) OnUpdatedHostDNS(list *DNSList) error {
|
||||||
dnsServer, err := dns.GetServerDns()
|
dnsServer, err := dns.GetServerDns()
|
||||||
|
|||||||
27
client/android/networks.go
Normal file
27
client/android/networks.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
//go:build android
|
||||||
|
|
||||||
|
package android
|
||||||
|
|
||||||
|
type Network struct {
|
||||||
|
Name string
|
||||||
|
Network string
|
||||||
|
Peer string
|
||||||
|
Status string
|
||||||
|
}
|
||||||
|
|
||||||
|
type NetworkArray struct {
|
||||||
|
items []Network
|
||||||
|
}
|
||||||
|
|
||||||
|
func (array *NetworkArray) Add(s Network) *NetworkArray {
|
||||||
|
array.items = append(array.items, s)
|
||||||
|
return array
|
||||||
|
}
|
||||||
|
|
||||||
|
func (array *NetworkArray) Get(i int) *Network {
|
||||||
|
return &array.items[i]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (array *NetworkArray) Size() int {
|
||||||
|
return len(array.items)
|
||||||
|
}
|
||||||
@@ -7,30 +7,23 @@ type PeerInfo struct {
|
|||||||
ConnStatus string // Todo replace to enum
|
ConnStatus string // Todo replace to enum
|
||||||
}
|
}
|
||||||
|
|
||||||
// PeerInfoCollection made for Java layer to get non default types as collection
|
// PeerInfoArray is a wrapper of []PeerInfo
|
||||||
type PeerInfoCollection interface {
|
|
||||||
Add(s string) PeerInfoCollection
|
|
||||||
Get(i int) string
|
|
||||||
Size() int
|
|
||||||
}
|
|
||||||
|
|
||||||
// PeerInfoArray is the implementation of the PeerInfoCollection
|
|
||||||
type PeerInfoArray struct {
|
type PeerInfoArray struct {
|
||||||
items []PeerInfo
|
items []PeerInfo
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add new PeerInfo to the collection
|
// Add new PeerInfo to the collection
|
||||||
func (array PeerInfoArray) Add(s PeerInfo) PeerInfoArray {
|
func (array *PeerInfoArray) Add(s PeerInfo) *PeerInfoArray {
|
||||||
array.items = append(array.items, s)
|
array.items = append(array.items, s)
|
||||||
return array
|
return array
|
||||||
}
|
}
|
||||||
|
|
||||||
// Get return an element of the collection
|
// Get return an element of the collection
|
||||||
func (array PeerInfoArray) Get(i int) *PeerInfo {
|
func (array *PeerInfoArray) Get(i int) *PeerInfo {
|
||||||
return &array.items[i]
|
return &array.items[i]
|
||||||
}
|
}
|
||||||
|
|
||||||
// Size return with the size of the collection
|
// Size return with the size of the collection
|
||||||
func (array PeerInfoArray) Size() int {
|
func (array *PeerInfoArray) Size() int {
|
||||||
return len(array.items)
|
return len(array.items)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Preferences export a subset of the internal config for gomobile
|
// Preferences exports a subset of the internal config for gomobile
|
||||||
type Preferences struct {
|
type Preferences struct {
|
||||||
configInput internal.ConfigInput
|
configInput internal.ConfigInput
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewPreferences create new Preferences instance
|
// NewPreferences creates a new Preferences instance
|
||||||
func NewPreferences(configPath string) *Preferences {
|
func NewPreferences(configPath string) *Preferences {
|
||||||
ci := internal.ConfigInput{
|
ci := internal.ConfigInput{
|
||||||
ConfigPath: configPath,
|
ConfigPath: configPath,
|
||||||
@@ -17,7 +17,7 @@ func NewPreferences(configPath string) *Preferences {
|
|||||||
return &Preferences{ci}
|
return &Preferences{ci}
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetManagementURL read url from config file
|
// GetManagementURL reads URL from config file
|
||||||
func (p *Preferences) GetManagementURL() (string, error) {
|
func (p *Preferences) GetManagementURL() (string, error) {
|
||||||
if p.configInput.ManagementURL != "" {
|
if p.configInput.ManagementURL != "" {
|
||||||
return p.configInput.ManagementURL, nil
|
return p.configInput.ManagementURL, nil
|
||||||
@@ -30,12 +30,12 @@ func (p *Preferences) GetManagementURL() (string, error) {
|
|||||||
return cfg.ManagementURL.String(), err
|
return cfg.ManagementURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetManagementURL store the given url and wait for commit
|
// SetManagementURL stores the given URL and waits for commit
|
||||||
func (p *Preferences) SetManagementURL(url string) {
|
func (p *Preferences) SetManagementURL(url string) {
|
||||||
p.configInput.ManagementURL = url
|
p.configInput.ManagementURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetAdminURL read url from config file
|
// GetAdminURL reads URL from config file
|
||||||
func (p *Preferences) GetAdminURL() (string, error) {
|
func (p *Preferences) GetAdminURL() (string, error) {
|
||||||
if p.configInput.AdminURL != "" {
|
if p.configInput.AdminURL != "" {
|
||||||
return p.configInput.AdminURL, nil
|
return p.configInput.AdminURL, nil
|
||||||
@@ -48,12 +48,12 @@ func (p *Preferences) GetAdminURL() (string, error) {
|
|||||||
return cfg.AdminURL.String(), err
|
return cfg.AdminURL.String(), err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetAdminURL store the given url and wait for commit
|
// SetAdminURL stores the given URL and waits for commit
|
||||||
func (p *Preferences) SetAdminURL(url string) {
|
func (p *Preferences) SetAdminURL(url string) {
|
||||||
p.configInput.AdminURL = url
|
p.configInput.AdminURL = url
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetPreSharedKey read preshared key from config file
|
// GetPreSharedKey reads pre-shared key from config file
|
||||||
func (p *Preferences) GetPreSharedKey() (string, error) {
|
func (p *Preferences) GetPreSharedKey() (string, error) {
|
||||||
if p.configInput.PreSharedKey != nil {
|
if p.configInput.PreSharedKey != nil {
|
||||||
return *p.configInput.PreSharedKey, nil
|
return *p.configInput.PreSharedKey, nil
|
||||||
@@ -66,12 +66,160 @@ func (p *Preferences) GetPreSharedKey() (string, error) {
|
|||||||
return cfg.PreSharedKey, err
|
return cfg.PreSharedKey, err
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetPreSharedKey store the given key and wait for commit
|
// SetPreSharedKey stores the given key and waits for commit
|
||||||
func (p *Preferences) SetPreSharedKey(key string) {
|
func (p *Preferences) SetPreSharedKey(key string) {
|
||||||
p.configInput.PreSharedKey = &key
|
p.configInput.PreSharedKey = &key
|
||||||
}
|
}
|
||||||
|
|
||||||
// Commit write out the changes into config file
|
// SetRosenpassEnabled stores whether Rosenpass is enabled
|
||||||
|
func (p *Preferences) SetRosenpassEnabled(enabled bool) {
|
||||||
|
p.configInput.RosenpassEnabled = &enabled
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRosenpassEnabled reads Rosenpass enabled status from config file
|
||||||
|
func (p *Preferences) GetRosenpassEnabled() (bool, error) {
|
||||||
|
if p.configInput.RosenpassEnabled != nil {
|
||||||
|
return *p.configInput.RosenpassEnabled, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.RosenpassEnabled, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetRosenpassPermissive stores the given permissive setting and waits for commit
|
||||||
|
func (p *Preferences) SetRosenpassPermissive(permissive bool) {
|
||||||
|
p.configInput.RosenpassPermissive = &permissive
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetRosenpassPermissive reads Rosenpass permissive setting from config file
|
||||||
|
func (p *Preferences) GetRosenpassPermissive() (bool, error) {
|
||||||
|
if p.configInput.RosenpassPermissive != nil {
|
||||||
|
return *p.configInput.RosenpassPermissive, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.RosenpassPermissive, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableClientRoutes reads disable client routes setting from config file
|
||||||
|
func (p *Preferences) GetDisableClientRoutes() (bool, error) {
|
||||||
|
if p.configInput.DisableClientRoutes != nil {
|
||||||
|
return *p.configInput.DisableClientRoutes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableClientRoutes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableClientRoutes stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableClientRoutes(disable bool) {
|
||||||
|
p.configInput.DisableClientRoutes = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableServerRoutes reads disable server routes setting from config file
|
||||||
|
func (p *Preferences) GetDisableServerRoutes() (bool, error) {
|
||||||
|
if p.configInput.DisableServerRoutes != nil {
|
||||||
|
return *p.configInput.DisableServerRoutes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableServerRoutes, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableServerRoutes stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableServerRoutes(disable bool) {
|
||||||
|
p.configInput.DisableServerRoutes = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableDNS reads disable DNS setting from config file
|
||||||
|
func (p *Preferences) GetDisableDNS() (bool, error) {
|
||||||
|
if p.configInput.DisableDNS != nil {
|
||||||
|
return *p.configInput.DisableDNS, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableDNS, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableDNS stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableDNS(disable bool) {
|
||||||
|
p.configInput.DisableDNS = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetDisableFirewall reads disable firewall setting from config file
|
||||||
|
func (p *Preferences) GetDisableFirewall() (bool, error) {
|
||||||
|
if p.configInput.DisableFirewall != nil {
|
||||||
|
return *p.configInput.DisableFirewall, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.DisableFirewall, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetDisableFirewall stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetDisableFirewall(disable bool) {
|
||||||
|
p.configInput.DisableFirewall = &disable
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetServerSSHAllowed reads server SSH allowed setting from config file
|
||||||
|
func (p *Preferences) GetServerSSHAllowed() (bool, error) {
|
||||||
|
if p.configInput.ServerSSHAllowed != nil {
|
||||||
|
return *p.configInput.ServerSSHAllowed, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
if cfg.ServerSSHAllowed == nil {
|
||||||
|
// Default to false for security on Android
|
||||||
|
return false, nil
|
||||||
|
}
|
||||||
|
return *cfg.ServerSSHAllowed, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetServerSSHAllowed stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetServerSSHAllowed(allowed bool) {
|
||||||
|
p.configInput.ServerSSHAllowed = &allowed
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetBlockInbound reads block inbound setting from config file
|
||||||
|
func (p *Preferences) GetBlockInbound() (bool, error) {
|
||||||
|
if p.configInput.BlockInbound != nil {
|
||||||
|
return *p.configInput.BlockInbound, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
cfg, err := internal.ReadConfig(p.configInput.ConfigPath)
|
||||||
|
if err != nil {
|
||||||
|
return false, err
|
||||||
|
}
|
||||||
|
return cfg.BlockInbound, err
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetBlockInbound stores the given value and waits for commit
|
||||||
|
func (p *Preferences) SetBlockInbound(block bool) {
|
||||||
|
p.configInput.BlockInbound = &block
|
||||||
|
}
|
||||||
|
|
||||||
|
// Commit writes out the changes to the config file
|
||||||
func (p *Preferences) Commit() error {
|
func (p *Preferences) Commit() error {
|
||||||
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
_, err := internal.UpdateOrCreateConfig(p.configInput)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
|||||||
return a.ipAnonymizer[ip]
|
return a.ipAnonymizer[ip]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
||||||
|
// Convert IP to netip.Addr
|
||||||
|
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||||
|
if !ok {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
anonIP := a.AnonymizeIP(ip)
|
||||||
|
|
||||||
|
return net.UDPAddr{
|
||||||
|
IP: anonIP.AsSlice(),
|
||||||
|
Port: addr.Port,
|
||||||
|
Zone: addr.Zone,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||||
|
|||||||
@@ -11,9 +11,12 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@@ -84,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||||
SystemInfo: debugSystemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
})
|
}
|
||||||
|
if debugUploadBundle {
|
||||||
|
request.UploadURL = debugUploadBundleURL
|
||||||
|
}
|
||||||
|
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||||
|
|
||||||
cmd.Println(resp.GetPath())
|
if resp.GetUploadFailureReason() != "" {
|
||||||
|
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||||
|
}
|
||||||
|
|
||||||
|
if debugUploadBundle {
|
||||||
|
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -208,23 +222,19 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
request := &proto.DebugBundleRequest{
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: statusOutput,
|
Status: statusOutput,
|
||||||
SystemInfo: debugSystemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
})
|
}
|
||||||
|
if debugUploadBundle {
|
||||||
|
request.UploadURL = debugUploadBundleURL
|
||||||
|
}
|
||||||
|
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable network map persistence after creating the debug bundle
|
|
||||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
|
||||||
Enabled: false,
|
|
||||||
}); err != nil {
|
|
||||||
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||||
@@ -239,7 +249,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println(resp.GetPath())
|
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||||
|
|
||||||
|
if resp.GetUploadFailureReason() != "" {
|
||||||
|
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||||
|
}
|
||||||
|
|
||||||
|
if debugUploadBundle {
|
||||||
|
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -326,3 +344,34 @@ func formatDuration(d time.Duration) string {
|
|||||||
s := d / time.Second
|
s := d / time.Second
|
||||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||||
|
var networkMap *mgmProto.NetworkMap
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if connectClient != nil {
|
||||||
|
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get latest network map: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
|
debug.GeneratorDependencies{
|
||||||
|
InternalConfig: config,
|
||||||
|
StatusRecorder: recorder,
|
||||||
|
NetworkMap: networkMap,
|
||||||
|
LogFile: logFilePath,
|
||||||
|
},
|
||||||
|
debug.BundleConfig{
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||||
|
}
|
||||||
|
|||||||
39
client/cmd/debug_unix.go
Normal file
39
client/cmd/debug_unix.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
usr1Ch := make(chan os.Signal, 1)
|
||||||
|
|
||||||
|
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-usr1Ch:
|
||||||
|
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
126
client/cmd/debug_windows.go
Normal file
126
client/cmd/debug_windows.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||||
|
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||||
|
|
||||||
|
waitTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||||
|
// Example usage with PowerShell:
|
||||||
|
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||||
|
// $evt.Set()
|
||||||
|
// $evt.Close()
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
env := os.Getenv(envListenEvent)
|
||||||
|
if env == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listenEvent, err := strconv.ParseBool(env)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !listenEvent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: restrict access by ACL
|
||||||
|
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||||
|
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||||
|
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||||
|
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||||
|
} else {
|
||||||
|
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventHandle == windows.InvalidHandle {
|
||||||
|
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||||
|
|
||||||
|
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
eventHandle windows.Handle,
|
||||||
|
) {
|
||||||
|
defer func() {
|
||||||
|
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case windows.WAIT_OBJECT_0:
|
||||||
|
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||||
|
|
||||||
|
// reset the event so it can be triggered again later (manual reset == 1)
|
||||||
|
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
case uint32(windows.WAIT_TIMEOUT):
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -55,6 +56,9 @@ var loginCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update host's static platform and system information
|
||||||
|
system.UpdateStaticInfo()
|
||||||
|
|
||||||
ic := internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
@@ -95,11 +99,11 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
DnsLabels: dnsLabelsReq,
|
DnsLabels: dnsLabelsReq,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@@ -192,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -240,7 +244,10 @@ func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBro
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isLinuxRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,23 +22,26 @@ import (
|
|||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
externalIPMapFlag = "external-ip-map"
|
externalIPMapFlag = "external-ip-map"
|
||||||
dnsResolverAddress = "dns-resolver-address"
|
dnsResolverAddress = "dns-resolver-address"
|
||||||
enableRosenpassFlag = "enable-rosenpass"
|
enableRosenpassFlag = "enable-rosenpass"
|
||||||
rosenpassPermissiveFlag = "rosenpass-permissive"
|
rosenpassPermissiveFlag = "rosenpass-permissive"
|
||||||
preSharedKeyFlag = "preshared-key"
|
preSharedKeyFlag = "preshared-key"
|
||||||
interfaceNameFlag = "interface-name"
|
interfaceNameFlag = "interface-name"
|
||||||
wireguardPortFlag = "wireguard-port"
|
wireguardPortFlag = "wireguard-port"
|
||||||
networkMonitorFlag = "network-monitor"
|
networkMonitorFlag = "network-monitor"
|
||||||
disableAutoConnectFlag = "disable-auto-connect"
|
disableAutoConnectFlag = "disable-auto-connect"
|
||||||
serverSSHAllowedFlag = "allow-server-ssh"
|
serverSSHAllowedFlag = "allow-server-ssh"
|
||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
systemInfoFlag = "system-info"
|
systemInfoFlag = "system-info"
|
||||||
blockLANAccessFlag = "block-lan-access"
|
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||||
|
uploadBundle = "upload-bundle"
|
||||||
|
uploadBundleURL = "upload-bundle-url"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -74,7 +77,9 @@ var (
|
|||||||
anonymizeFlag bool
|
anonymizeFlag bool
|
||||||
debugSystemInfoFlag bool
|
debugSystemInfoFlag bool
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
blockLANAccess bool
|
debugUploadBundle bool
|
||||||
|
debugUploadBundleURL string
|
||||||
|
lazyConnEnabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@@ -179,8 +184,11 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||||
|
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||||
|
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||||
|
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newSVCConfig() *service.Config {
|
func newSVCConfig() *service.Config {
|
||||||
return &service.Config{
|
config := &service.Config{
|
||||||
Name: serviceName,
|
Name: serviceName,
|
||||||
DisplayName: "Netbird",
|
DisplayName: "Netbird",
|
||||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
Description: "Netbird mesh network client",
|
||||||
Option: make(service.KeyValue),
|
Option: make(service.KeyValue),
|
||||||
|
EnvVars: make(map[string]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||||
|
|||||||
@@ -16,12 +16,17 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *program) Start(svc service.Service) error {
|
func (p *program) Start(svc service.Service) error {
|
||||||
// Start should not block. Do the actual work async.
|
// Start should not block. Do the actual work async.
|
||||||
log.Info("starting Netbird service") //nolint
|
log.Info("starting Netbird service") //nolint
|
||||||
|
|
||||||
|
// Collect static system and platform information
|
||||||
|
system.UpdateStaticInfo()
|
||||||
|
|
||||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||||
p.serv = grpc.NewServer()
|
p.serv = grpc.NewServer()
|
||||||
|
|
||||||
@@ -115,6 +120,7 @@ var runCmd = &cobra.Command{
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
|
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
|
|||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if logFile != "console" {
|
if logFile != "" {
|
||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func init() {
|
|||||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -69,7 +69,10 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if resp.GetStatus() == string(internal.StatusNeedsLogin) || resp.GetStatus() == string(internal.StatusLoginFailed) {
|
status := resp.GetStatus()
|
||||||
|
|
||||||
|
if status == string(internal.StatusNeedsLogin) || status == string(internal.StatusLoginFailed) ||
|
||||||
|
status == string(internal.StatusSessionExpired) {
|
||||||
cmd.Printf("Daemon status: %s\n\n"+
|
cmd.Printf("Daemon status: %s\n\n"+
|
||||||
"Run UP command to log in with SSO (interactive login):\n\n"+
|
"Run UP command to log in with SSO (interactive login):\n\n"+
|
||||||
" netbird up \n\n"+
|
" netbird up \n\n"+
|
||||||
@@ -117,7 +120,7 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
}
|
}
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
|
||||||
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true})
|
resp, err := proto.NewDaemonServiceClient(conn).Status(ctx, &proto.StatusRequest{GetFullPeerStatus: true, ShouldRunProbes: true})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
return nil, fmt.Errorf("status failed: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
@@ -127,12 +130,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "idle", "connecting", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
if strings.ToLower(statusFilter) != "" {
|
||||||
enableDetailFlagWhenFilterFlag()
|
enableDetailFlagWhenFilterFlag()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ipsFilter) > 0 {
|
if len(ipsFilter) > 0 {
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ const (
|
|||||||
disableServerRoutesFlag = "disable-server-routes"
|
disableServerRoutesFlag = "disable-server-routes"
|
||||||
disableDNSFlag = "disable-dns"
|
disableDNSFlag = "disable-dns"
|
||||||
disableFirewallFlag = "disable-firewall"
|
disableFirewallFlag = "disable-firewall"
|
||||||
|
blockLANAccessFlag = "block-lan-access"
|
||||||
|
blockInboundFlag = "block-inbound"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -13,6 +15,8 @@ var (
|
|||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
disableDNS bool
|
disableDNS bool
|
||||||
disableFirewall bool
|
disableFirewall bool
|
||||||
|
blockLANAccess bool
|
||||||
|
blockInbound bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -28,4 +32,11 @@ func init() {
|
|||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||||
|
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||||
|
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||||
|
"This overrides any policies received from the management service.")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -98,7 +98,12 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
|||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
|||||||
Example: `
|
Example: `
|
||||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
Args: cobra.ExactArgs(3),
|
Args: cobra.ExactArgs(3),
|
||||||
RunE: tracePacket,
|
RunE: tracePacket,
|
||||||
@@ -118,7 +118,7 @@ func tracePacket(cmd *cobra.Command, args []string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
cmd.Printf("Packet trace %s:%d → %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
for _, stage := range resp.Stages {
|
for _, stage := range resp.Stages {
|
||||||
if stage.ForwardingDetails != nil {
|
if stage.ForwardingDetails != nil {
|
||||||
|
|||||||
254
client/cmd/up.go
254
client/cmd/up.go
@@ -55,12 +55,11 @@ func init() {
|
|||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||||
)
|
)
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||||
`Sets DNS labels`+
|
`Sets DNS labels`+
|
||||||
@@ -119,6 +118,124 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setup config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
providedSetupKey, err := getSetupKey()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := internal.UpdateOrCreateConfig(*ic)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||||
|
|
||||||
|
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("foreground login failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
|
SetupCloseHandler(ctx, cancel)
|
||||||
|
|
||||||
|
r := peer.NewRecorder(config.ManagementURL.String())
|
||||||
|
r.GetFullStatus()
|
||||||
|
|
||||||
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
|
return connectClient.Run(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||||
|
"If the daemon is not running please run: "+
|
||||||
|
"\nnetbird service install \nnetbird service start\n", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
|
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Status == string(internal.StatusConnected) {
|
||||||
|
cmd.Println("Already connected")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
providedSetupKey, err := getSetupKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get setup key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setup login request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var loginErr error
|
||||||
|
var loginResp *proto.LoginResponse
|
||||||
|
|
||||||
|
err = WithBackOff(func() error {
|
||||||
|
var backOffErr error
|
||||||
|
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
||||||
|
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||||
|
s.Code() == codes.PermissionDenied ||
|
||||||
|
s.Code() == codes.NotFound ||
|
||||||
|
s.Code() == codes.Unimplemented) {
|
||||||
|
loginErr = backOffErr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return backOffErr
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginErr != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", loginErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||||
|
return fmt.Errorf("call service up method: %v", err)
|
||||||
|
}
|
||||||
|
cmd.Println("Connected")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
||||||
ic := internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
@@ -143,7 +260,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
ic.InterfaceName = &interfaceName
|
ic.InterfaceName = &interfaceName
|
||||||
}
|
}
|
||||||
@@ -194,83 +311,29 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.BlockLANAccess = &blockLANAccess
|
ic.BlockLANAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
if cmd.Flag(blockInboundFlag).Changed {
|
||||||
if err != nil {
|
ic.BlockInbound = &blockInbound
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(ic)
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
if err != nil {
|
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
return fmt.Errorf("get config file: %v", err)
|
|
||||||
}
|
}
|
||||||
|
return &ic, nil
|
||||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("foreground login failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
|
||||||
SetupCloseHandler(ctx, cancel)
|
|
||||||
|
|
||||||
r := peer.NewRecorder(config.ManagementURL.String())
|
|
||||||
r.GetFullStatus()
|
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
|
||||||
return connectClient.Run(nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
|
||||||
"If the daemon is not running please run: "+
|
|
||||||
"\nnetbird service install \nnetbird service start\n", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
|
|
||||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if status.Status == string(internal.StatusConnected) {
|
|
||||||
cmd.Println("Already connected")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
NatExternalIPs: natExternalIPs,
|
NatExternalIPs: natExternalIPs,
|
||||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
DnsLabels: dnsLabels,
|
DnsLabels: dnsLabels,
|
||||||
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@@ -295,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
loginRequest.InterfaceName = &interfaceName
|
loginRequest.InterfaceName = &interfaceName
|
||||||
}
|
}
|
||||||
@@ -330,45 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.BlockLanAccess = &blockLANAccess
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
var loginErr error
|
if cmd.Flag(blockInboundFlag).Changed {
|
||||||
|
loginRequest.BlockInbound = &blockInbound
|
||||||
var loginResp *proto.LoginResponse
|
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
var backOffErr error
|
|
||||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
|
||||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
|
||||||
s.Code() == codes.PermissionDenied ||
|
|
||||||
s.Code() == codes.NotFound ||
|
|
||||||
s.Code() == codes.Unimplemented) {
|
|
||||||
loginErr = backOffErr
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return backOffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if loginErr != nil {
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
return fmt.Errorf("login failed: %v", loginErr)
|
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
}
|
}
|
||||||
|
return &loginRequest, nil
|
||||||
if loginResp.NeedsSSOLogin {
|
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
|
||||||
return fmt.Errorf("call service up method: %v", err)
|
|
||||||
}
|
|
||||||
cmd.Println("Connected")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNATExternalIPs(list []string) error {
|
func validateNATExternalIPs(list []string) error {
|
||||||
|
|||||||
@@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if !destination.Addr().Is4() {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
@@ -148,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -199,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
firewall.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
@@ -220,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +252,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package iptables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -57,18 +57,18 @@ type ruleInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type routeFilteringRuleParams struct {
|
type routeFilteringRuleParams struct {
|
||||||
Sources []netip.Prefix
|
Source firewall.Network
|
||||||
Destination netip.Prefix
|
Destination firewall.Network
|
||||||
Proto firewall.Protocol
|
Proto firewall.Protocol
|
||||||
SPort *firewall.Port
|
SPort *firewall.Port
|
||||||
DPort *firewall.Port
|
DPort *firewall.Port
|
||||||
Direction firewall.RuleDirection
|
Direction firewall.RuleDirection
|
||||||
Action firewall.Action
|
Action firewall.Action
|
||||||
SetName string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type routeRules map[string][]string
|
type routeRules map[string][]string
|
||||||
|
|
||||||
|
// the ipset library currently does not support comments, so we use the name only (string)
|
||||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
@@ -129,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
@@ -140,27 +140,28 @@ func (r *router) AddRouteFiltering(
|
|||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var setName string
|
var source firewall.Network
|
||||||
if len(sources) > 1 {
|
if len(sources) > 1 {
|
||||||
setName = firewall.GenerateSetName(sources)
|
source.Set = firewall.NewPrefixSet(sources)
|
||||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
} else if len(sources) > 0 {
|
||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
source.Prefix = sources[0]
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
params := routeFilteringRuleParams{
|
params := routeFilteringRuleParams{
|
||||||
Sources: sources,
|
Source: source,
|
||||||
Destination: destination,
|
Destination: destination,
|
||||||
Proto: proto,
|
Proto: proto,
|
||||||
SPort: sPort,
|
SPort: sPort,
|
||||||
DPort: dPort,
|
DPort: dPort,
|
||||||
Action: action,
|
Action: action,
|
||||||
SetName: setName,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := genRouteFilteringRuleSpec(params)
|
rule, err := r.genRouteRuleSpec(params, sources)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate route rule spec: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
var err error
|
|
||||||
if action == firewall.ActionDrop {
|
if action == firewall.ActionDrop {
|
||||||
// after the established rule
|
// after the established rule
|
||||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||||
@@ -183,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
setName := r.findSetNameInRule(rule)
|
|
||||||
|
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("delete route rule: %v", err)
|
return fmt.Errorf("delete route rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
if setName != "" {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
@@ -204,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) findSetNameInRule(rule []string) string {
|
func (r *router) decrementSetCounter(rule []string) error {
|
||||||
for i, arg := range rule {
|
sets := r.findSets(rule)
|
||||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
var merr *multierror.Error
|
||||||
return rule[i+3]
|
for _, setName := range sets {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSets(rule []string) []string {
|
||||||
|
var sets []string
|
||||||
|
for i, arg := range rule {
|
||||||
|
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||||
|
sets = append(sets, rule[i+3])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
@@ -231,15 +241,13 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
if err := ipset.Destroy(setName); err != nil {
|
if err := ipset.Destroy(setName); err != nil {
|
||||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Deleted unused ipset %s", setName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.legacyManagement {
|
if r.legacyManagement {
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
@@ -266,16 +274,14 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
if pair.Masquerade {
|
||||||
log.Errorf("%v", err)
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
}
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
|
||||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
@@ -313,8 +319,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
|
||||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -599,12 +607,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
rule = append(rule,
|
rule = append(rule,
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
"--ctstate", "NEW",
|
"--ctstate", "NEW",
|
||||||
"-s", pair.Source.String(),
|
)
|
||||||
"-d", pair.Destination.String(),
|
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -s: %w", err)
|
||||||
|
}
|
||||||
|
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -d: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule, sourceExp...)
|
||||||
|
rule = append(rule, destExp...)
|
||||||
|
rule = append(rule,
|
||||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
||||||
|
// TODO: rollback ipset counter
|
||||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -622,6 +644,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("marking rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
@@ -787,17 +813,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
|
||||||
var rule []string
|
var rule []string
|
||||||
|
|
||||||
if params.SetName != "" {
|
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
|
||||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
if err != nil {
|
||||||
} else if len(params.Sources) > 0 {
|
return nil, fmt.Errorf("apply network -s: %w", err)
|
||||||
source := params.Sources[0]
|
|
||||||
rule = append(rule, "-s", source.String())
|
}
|
||||||
|
destExp, err := r.applyNetwork("-d", params.Destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rule = append(rule, "-d", params.Destination.String())
|
rule = append(rule, sourceExp...)
|
||||||
|
rule = append(rule, destExp...)
|
||||||
|
|
||||||
if params.Proto != firewall.ProtocolALL {
|
if params.Proto != firewall.ProtocolALL {
|
||||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||||
@@ -807,7 +837,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
|||||||
|
|
||||||
rule = append(rule, "-j", actionToStr(params.Action))
|
rule = append(rule, "-j", actionToStr(params.Action))
|
||||||
|
|
||||||
return rule
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||||
|
direction := "src"
|
||||||
|
if flag == "-d" {
|
||||||
|
direction = "dst"
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsSet() {
|
||||||
|
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||||
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||||
|
}
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return []string{flag, network.Prefix.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:nilnil
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
// TODO: Implement IPv6 support
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr == nil {
|
||||||
|
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyPort(flag string, port *firewall.Port) []string {
|
func applyPort(flag string, port *firewall.Port) []string {
|
||||||
|
|||||||
@@ -60,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -332,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
@@ -347,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
|
var source firewall.Network
|
||||||
|
if len(tt.sources) > 1 {
|
||||||
|
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||||
|
} else if len(tt.sources) > 0 {
|
||||||
|
source.Prefix = tt.sources[0]
|
||||||
|
}
|
||||||
// Verify rule content
|
// Verify rule content
|
||||||
params := routeFilteringRuleParams{
|
params := routeFilteringRuleParams{
|
||||||
Sources: tt.sources,
|
Source: source,
|
||||||
Destination: tt.destination,
|
Destination: firewall.Network{Prefix: tt.destination},
|
||||||
Proto: tt.proto,
|
Proto: tt.proto,
|
||||||
SPort: tt.sPort,
|
SPort: tt.sPort,
|
||||||
DPort: tt.dPort,
|
DPort: tt.dPort,
|
||||||
Action: tt.action,
|
Action: tt.action,
|
||||||
SetName: "",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedRule := genRouteFilteringRuleSpec(params)
|
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||||
|
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||||
|
|
||||||
if tt.expectSet {
|
if tt.expectSet {
|
||||||
setName := firewall.GenerateSetName(tt.sources)
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
params.SetName = setName
|
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||||
expectedRule = genRouteFilteringRuleSpec(params)
|
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||||
|
|
||||||
// Check if the set was created
|
// Check if the set was created
|
||||||
_, exists := r.ipsetCounter.Get(setName)
|
_, exists := r.ipsetCounter.Get(setName)
|
||||||
@@ -378,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindSetNameInRule(t *testing.T) {
|
||||||
|
r := &router{}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
rule []string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic rule with two sets",
|
||||||
|
rule: []string{
|
||||||
|
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
|
||||||
|
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
|
||||||
|
},
|
||||||
|
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No sets",
|
||||||
|
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple sets with different positions",
|
||||||
|
rule: []string{
|
||||||
|
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
|
||||||
|
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
|
||||||
|
},
|
||||||
|
expected: []string{"set1", "set-abc123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Boundary case - sequence appears at end",
|
||||||
|
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
|
||||||
|
expected: []string{"final-set"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Incomplete pattern - missing set name",
|
||||||
|
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := r.findSets(tc.rule)
|
||||||
|
|
||||||
|
if len(result) != len(tc.expected) {
|
||||||
|
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, set := range result {
|
||||||
|
if set != tc.expected[i] {
|
||||||
|
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -43,6 +40,18 @@ const (
|
|||||||
// Action is the action to be taken on a rule
|
// Action is the action to be taken on a rule
|
||||||
type Action int
|
type Action int
|
||||||
|
|
||||||
|
// String returns the string representation of the action
|
||||||
|
func (a Action) String() string {
|
||||||
|
switch a {
|
||||||
|
case ActionAccept:
|
||||||
|
return "accept"
|
||||||
|
case ActionDrop:
|
||||||
|
return "drop"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ActionAccept is the action to accept a packet
|
// ActionAccept is the action to accept a packet
|
||||||
ActionAccept Action = iota
|
ActionAccept Action = iota
|
||||||
@@ -50,6 +59,33 @@ const (
|
|||||||
ActionDrop
|
ActionDrop
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Network is a rule destination, either a set or a prefix
|
||||||
|
type Network struct {
|
||||||
|
Set Set
|
||||||
|
Prefix netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the destination
|
||||||
|
func (d Network) String() string {
|
||||||
|
if d.Prefix.IsValid() {
|
||||||
|
return d.Prefix.String()
|
||||||
|
}
|
||||||
|
if d.IsSet() {
|
||||||
|
return d.Set.HashedName()
|
||||||
|
}
|
||||||
|
return "<invalid network>"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSet returns true if the destination is a set
|
||||||
|
func (d Network) IsSet() bool {
|
||||||
|
return d.Set != Set{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrefix returns true if the destination is a valid prefix
|
||||||
|
func (d Network) IsPrefix() bool {
|
||||||
|
return d.Prefix.IsValid()
|
||||||
|
}
|
||||||
|
|
||||||
// Manager is the high level abstraction of a firewall manager
|
// Manager is the high level abstraction of a firewall manager
|
||||||
//
|
//
|
||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
@@ -80,13 +116,14 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
|
IsStateful() bool
|
||||||
|
|
||||||
AddRouteFiltering(
|
AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination Network,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort, dPort *Port,
|
||||||
dPort *Port,
|
|
||||||
action Action,
|
action Action,
|
||||||
) (Rule, error)
|
) (Rule, error)
|
||||||
|
|
||||||
@@ -119,6 +156,9 @@ type Manager interface {
|
|||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
DeleteDNATRule(Rule) error
|
DeleteDNATRule(Rule) error
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
@@ -153,22 +193,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
|
||||||
func GenerateSetName(sources []netip.Prefix) string {
|
|
||||||
// sort for consistent naming
|
|
||||||
SortPrefixes(sources)
|
|
||||||
|
|
||||||
var sourcesStr strings.Builder
|
|
||||||
for _, src := range sources {
|
|
||||||
sourcesStr.WriteString(src.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
|
||||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
|
||||||
|
|
||||||
return fmt.Sprintf("nb-%s", shortHash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||||
if len(prefixes) == 0 {
|
if len(prefixes) == 0 {
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result1 := manager.GenerateSetName(prefixes1)
|
result1 := manager.NewPrefixSet(prefixes1)
|
||||||
result2 := manager.GenerateSetName(prefixes2)
|
result2 := manager.NewPrefixSet(prefixes2)
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||||
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("10.0.0.0/8"),
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result := manager.GenerateSetName(prefixes)
|
result := manager.NewPrefixSet(prefixes)
|
||||||
|
|
||||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error matching regex: %v", err)
|
t.Fatalf("Error matching regex: %v", err)
|
||||||
}
|
}
|
||||||
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
result1 := manager.NewPrefixSet([]netip.Prefix{})
|
||||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
result2 := manager.NewPrefixSet([]netip.Prefix{})
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||||
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result1 := manager.GenerateSetName(prefixes1)
|
result1 := manager.NewPrefixSet(prefixes1)
|
||||||
result2 := manager.GenerateSetName(prefixes2)
|
result2 := manager.NewPrefixSet(prefixes2)
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouterPair struct {
|
type RouterPair struct {
|
||||||
ID route.ID
|
ID route.ID
|
||||||
Source netip.Prefix
|
Source Network
|
||||||
Destination netip.Prefix
|
Destination Network
|
||||||
Masquerade bool
|
Masquerade bool
|
||||||
Inverse bool
|
Inverse bool
|
||||||
}
|
}
|
||||||
|
|||||||
74
client/firewall/manager/set.go
Normal file
74
client/firewall/manager/set.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Set struct {
|
||||||
|
hash [4]byte
|
||||||
|
comment string
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the set: hashed name and comment
|
||||||
|
func (h Set) String() string {
|
||||||
|
if h.comment == "" {
|
||||||
|
return h.HashedName()
|
||||||
|
}
|
||||||
|
return h.HashedName() + ": " + h.comment
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashedName returns the string representation of the hash
|
||||||
|
func (h Set) HashedName() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"nb-%s",
|
||||||
|
hex.EncodeToString(h.hash[:]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comment returns the comment of the set
|
||||||
|
func (h Set) Comment() string {
|
||||||
|
return h.comment
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||||
|
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||||
|
// sort for consistent naming
|
||||||
|
SortPrefixes(prefixes)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
for _, src := range prefixes {
|
||||||
|
bytes, err := src.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to marshal prefix %s: %v", src, err)
|
||||||
|
}
|
||||||
|
hash.Write(bytes)
|
||||||
|
}
|
||||||
|
var set Set
|
||||||
|
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||||
|
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDomainSet generates a unique name for an ipset based on the given domains.
|
||||||
|
func NewDomainSet(domains domain.List) Set {
|
||||||
|
slices.Sort(domains)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
for _, d := range domains {
|
||||||
|
hash.Write([]byte(d.PunycodeString()))
|
||||||
|
}
|
||||||
|
set := Set{
|
||||||
|
comment: domains.SafeString(),
|
||||||
|
}
|
||||||
|
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||||
|
|
||||||
|
return set
|
||||||
|
}
|
||||||
@@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if !destination.Addr().Is4() {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
@@ -171,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -242,7 +245,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close closes the firewall manager
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -325,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,6 +368,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
expectedExprs2 := []expr.Any{
|
expectedExprs2 := []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: add.AsSlice(),
|
Data: ip.AsSlice(),
|
||||||
},
|
},
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -282,14 +273,14 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
&fw.Port{Values: []uint16{443}},
|
&fw.Port{Values: []uint16{443}},
|
||||||
@@ -298,8 +289,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
require.NoError(t, err, "failed to add route filtering rule")
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
pair := fw.RouterPair{
|
pair := fw.RouterPair{
|
||||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
err = manager.AddNatRule(pair)
|
err = manager.AddNatRule(pair)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
@@ -44,9 +43,14 @@ const (
|
|||||||
const refreshRulesMapError = "refresh rules map: %w"
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type setInput struct {
|
||||||
|
set firewall.Set
|
||||||
|
prefixes []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
@@ -54,7 +58,7 @@ type router struct {
|
|||||||
chains map[string]*nftables.Chain
|
chains map[string]*nftables.Chain
|
||||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
@@ -163,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
@@ -316,7 +320,7 @@ func (r *router) setupDataPlaneMark() error {
|
|||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
@@ -331,23 +335,29 @@ func (r *router) AddRouteFiltering(
|
|||||||
chain := r.chains[chainNameRoutingFw]
|
chain := r.chains[chainNameRoutingFw]
|
||||||
var exprs []expr.Any
|
var exprs []expr.Any
|
||||||
|
|
||||||
|
var source firewall.Network
|
||||||
switch {
|
switch {
|
||||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||||
case len(sources) == 1:
|
case len(sources) == 1:
|
||||||
// If there's only one source, we can use it directly
|
// If there's only one source, we can use it directly
|
||||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
source.Prefix = sources[0]
|
||||||
default:
|
default:
|
||||||
// If there are multiple sources, create or get an ipset
|
// If there are multiple sources, use a set
|
||||||
var err error
|
source.Set = firewall.NewPrefixSet(sources)
|
||||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle destination
|
sourceExp, err := r.applyNetwork(source, sources, true)
|
||||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, sourceExp...)
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
// Handle protocol
|
// Handle protocol
|
||||||
if proto != firewall.ProtocolALL {
|
if proto != firewall.ProtocolALL {
|
||||||
@@ -391,39 +401,27 @@ func (r *router) AddRouteFiltering(
|
|||||||
rule = r.conn.AddRule(rule)
|
rule = r.conn.AddRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return nil, fmt.Errorf(flushError, err)
|
return nil, fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[string(ruleKey)] = rule
|
r.rules[string(ruleKey)] = rule
|
||||||
|
|
||||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||||
setName := firewall.GenerateSetName(sources)
|
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
set: set,
|
||||||
|
prefixes: prefixes,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs,
|
return getIpSetExprs(ref, isSource)
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ref.Out.Name,
|
|
||||||
SetID: ref.Out.ID,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return exprs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
@@ -442,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
setName := r.findSetNameInRule(nftRule)
|
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
return fmt.Errorf("delete: %w", err)
|
return fmt.Errorf("delete: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if setName != "" {
|
|
||||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
|
||||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||||
// overlapping prefixes will result in an error, so we need to merge them
|
// overlapping prefixes will result in an error, so we need to merge them
|
||||||
sources = firewall.MergeIPRanges(sources)
|
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||||
|
|
||||||
set := &nftables.Set{
|
nfset := &nftables.Set{
|
||||||
Name: setName,
|
Name: setName,
|
||||||
Table: r.workTable,
|
Comment: input.set.Comment(),
|
||||||
|
Table: r.workTable,
|
||||||
// required for prefixes
|
// required for prefixes
|
||||||
Interval: true,
|
Interval: true,
|
||||||
KeyType: nftables.TypeIPAddr,
|
KeyType: nftables.TypeIPAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
elements := convertPrefixesToSet(prefixes)
|
||||||
|
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||||
|
|
||||||
|
return nfset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||||
var elements []nftables.SetElement
|
var elements []nftables.SetElement
|
||||||
for _, prefix := range sources {
|
for _, prefix := range prefixes {
|
||||||
// TODO: Implement IPv6 support
|
// TODO: Implement IPv6 support
|
||||||
if prefix.Addr().Is6() {
|
if prefix.Addr().Is6() {
|
||||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -493,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
|
|||||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
return elements
|
||||||
if err := r.conn.AddSet(set, elements); err != nil {
|
|
||||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
|
||||||
|
|
||||||
return set, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateLastIP determines the last IP in a given prefix.
|
// calculateLastIP determines the last IP in a given prefix.
|
||||||
@@ -528,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||||
r.conn.DelSet(set)
|
r.conn.DelSet(nfset)
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
@@ -538,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
|
||||||
for _, e := range rule.Exprs {
|
sets := r.findSets(rule)
|
||||||
if lookup, ok := e.(*expr.Lookup); ok {
|
|
||||||
return lookup.SetName
|
var merr *multierror.Error
|
||||||
|
for _, setName := range sets {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSets(rule *nftables.Rule) []string {
|
||||||
|
var sets []string
|
||||||
|
for _, e := range rule.Exprs {
|
||||||
|
if lookup, ok := e.(*expr.Lookup); ok {
|
||||||
|
sets = append(sets, lookup.SetName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||||
@@ -560,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
|||||||
|
|
||||||
// AddNatRule appends a nftables rule pair to the nat chain
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -586,7 +595,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
// TODO: rollback ipset counter
|
||||||
|
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -594,19 +604,22 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
op := expr.CmpOpEq
|
op := expr.CmpOpEq
|
||||||
if pair.Inverse {
|
if pair.Inverse {
|
||||||
op = expr.CmpOpNeq
|
op = expr.CmpOpNeq
|
||||||
}
|
}
|
||||||
|
|
||||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
exprs := []expr.Any{
|
||||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
|
||||||
exprs := getCtNewExprs()
|
|
||||||
exprs = append(exprs,
|
|
||||||
// interface matching
|
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: expr.MetaKeyIIFNAME,
|
Key: expr.MetaKeyIIFNAME,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
@@ -616,7 +629,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(r.wgIface.Name()),
|
Data: ifname(r.wgIface.Name()),
|
||||||
},
|
},
|
||||||
)
|
}
|
||||||
|
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||||
|
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||||
|
exprs = append(exprs, getCtNewExprs()...)
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
exprs = append(exprs, destExp...)
|
exprs = append(exprs, destExp...)
|
||||||
@@ -646,7 +662,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameManglePrerouting],
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
Exprs: exprs,
|
Exprs: exprs,
|
||||||
@@ -729,8 +747,15 @@ func (r *router) addPostroutingRules() error {
|
|||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
&expr.Counter{},
|
&expr.Counter{},
|
||||||
@@ -739,7 +764,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
exprs = append(exprs, sourceExp...)
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
@@ -752,7 +778,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
Exprs: expression,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@@ -767,11 +793,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
|
||||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -974,20 +1002,18 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
|
|
||||||
// RemoveNatRule removes the prerouting mark rule
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if pair.Masquerade {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
}
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
@@ -995,10 +1021,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
// TODO: rollback set counter
|
||||||
|
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1006,16 +1032,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
err := r.conn.DelRule(rule)
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -1027,7 +1056,7 @@ func (r *router) refreshRulesMap() error {
|
|||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
return fmt.Errorf(" unable to list rules: %v", err)
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
@@ -1301,13 +1330,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||||
var offset uint32
|
if err != nil {
|
||||||
if source {
|
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||||
offset = 12 // src offset
|
}
|
||||||
} else {
|
|
||||||
offset = 16 // dst offset
|
elements := convertPrefixesToSet(prefixes)
|
||||||
|
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||||
|
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||||
|
func (r *router) applyNetwork(
|
||||||
|
network firewall.Network,
|
||||||
|
setPrefixes []netip.Prefix,
|
||||||
|
isSource bool,
|
||||||
|
) ([]expr.Any, error) {
|
||||||
|
if network.IsSet() {
|
||||||
|
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("source: %w", err)
|
||||||
|
}
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return applyPrefix(network.Prefix, isSource), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||||
|
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||||
|
// dst offset
|
||||||
|
offset := uint32(16)
|
||||||
|
if isSource {
|
||||||
|
// src offset
|
||||||
|
offset = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
ones := prefix.Bits()
|
ones := prefix.Bits()
|
||||||
@@ -1415,3 +1485,27 @@ func getCtNewExprs() []expr.Any {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||||
|
|
||||||
|
// dst offset
|
||||||
|
offset := uint32(16)
|
||||||
|
if isSource {
|
||||||
|
// src offset
|
||||||
|
offset = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ref.Out.Name,
|
||||||
|
SetID: ref.Out.ID,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build CIDR matching expressions
|
// Build CIDR matching expressions
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||||
|
|
||||||
// Combine all expressions in the correct order
|
// Combine all expressions in the correct order
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
setName := firewall.GenerateSetName(tt.sources)
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
set, err := r.createIpSet(setName, tt.sources)
|
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Failed to create IP set: %v", err)
|
t.Logf("Failed to create IP set: %v", err)
|
||||||
printNftSets()
|
printNftSets()
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ var (
|
|||||||
Name: "Insert Forwarding IPV4 Rule",
|
Name: "Insert Forwarding IPV4 Rule",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -24,8 +24,8 @@ var (
|
|||||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -40,8 +40,8 @@ var (
|
|||||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +21,7 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
|||||||
@@ -62,5 +62,5 @@ type ConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (c ConnKey) String() string {
|
func (c ConnKey) String() string {
|
||||||
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
return fmt.Sprintf("%s:%d → %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package conntrack
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
@@ -19,6 +20,10 @@ const (
|
|||||||
DefaultICMPTimeout = 30 * time.Second
|
DefaultICMPTimeout = 30 * time.Second
|
||||||
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
ICMPCleanupInterval = 15 * time.Second
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
|
||||||
|
// MaxICMPPayloadLength is the maximum length of ICMP payload we consider for original packet info,
|
||||||
|
// which includes the IP header (20 bytes) and transport header (8 bytes)
|
||||||
|
MaxICMPPayloadLength = 28
|
||||||
)
|
)
|
||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
@@ -29,7 +34,7 @@ type ICMPConnKey struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (i ICMPConnKey) String() string {
|
func (i ICMPConnKey) String() string {
|
||||||
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
return fmt.Sprintf("%s → %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
@@ -50,6 +55,72 @@ type ICMPTracker struct {
|
|||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ICMPInfo holds ICMP type, code, and payload for lazy string formatting in logs
|
||||||
|
type ICMPInfo struct {
|
||||||
|
TypeCode layers.ICMPv4TypeCode
|
||||||
|
PayloadData [MaxICMPPayloadLength]byte
|
||||||
|
// actual length of valid data
|
||||||
|
PayloadLen int
|
||||||
|
}
|
||||||
|
|
||||||
|
// String implements fmt.Stringer for lazy evaluation in log messages
|
||||||
|
func (info ICMPInfo) String() string {
|
||||||
|
if info.isErrorMessage() && info.PayloadLen >= MaxICMPPayloadLength {
|
||||||
|
if origInfo := info.parseOriginalPacket(); origInfo != "" {
|
||||||
|
return fmt.Sprintf("%s (original: %s)", info.TypeCode, origInfo)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return info.TypeCode.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
// isErrorMessage returns true if this ICMP type carries original packet info
|
||||||
|
func (info ICMPInfo) isErrorMessage() bool {
|
||||||
|
typ := info.TypeCode.Type()
|
||||||
|
return typ == 3 || // Destination Unreachable
|
||||||
|
typ == 5 || // Redirect
|
||||||
|
typ == 11 || // Time Exceeded
|
||||||
|
typ == 12 // Parameter Problem
|
||||||
|
}
|
||||||
|
|
||||||
|
// parseOriginalPacket extracts info about the original packet from ICMP payload
|
||||||
|
func (info ICMPInfo) parseOriginalPacket() string {
|
||||||
|
if info.PayloadLen < MaxICMPPayloadLength {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: handle IPv6
|
||||||
|
if version := (info.PayloadData[0] >> 4) & 0xF; version != 4 {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := info.PayloadData[9]
|
||||||
|
srcIP := net.IP(info.PayloadData[12:16])
|
||||||
|
dstIP := net.IP(info.PayloadData[16:20])
|
||||||
|
|
||||||
|
transportData := info.PayloadData[20:]
|
||||||
|
|
||||||
|
switch nftypes.Protocol(protocol) {
|
||||||
|
case nftypes.TCP:
|
||||||
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
|
return fmt.Sprintf("TCP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
case nftypes.UDP:
|
||||||
|
srcPort := uint16(transportData[0])<<8 | uint16(transportData[1])
|
||||||
|
dstPort := uint16(transportData[2])<<8 | uint16(transportData[3])
|
||||||
|
return fmt.Sprintf("UDP %s:%d → %s:%d", srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
case nftypes.ICMP:
|
||||||
|
icmpType := transportData[0]
|
||||||
|
icmpCode := transportData[1]
|
||||||
|
return fmt.Sprintf("ICMP %s → %s (type %d code %d)", srcIP, dstIP, icmpType, icmpCode)
|
||||||
|
|
||||||
|
default:
|
||||||
|
return fmt.Sprintf("Proto %d %s → %s", protocol, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
@@ -93,30 +164,64 @@ func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP connection
|
// TrackOutbound records an outbound ICMP connection
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
func (t *ICMPTracker) TrackOutbound(
|
||||||
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound records an inbound ICMP Echo Request
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
func (t *ICMPTracker) TrackInbound(
|
||||||
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
ruleId []byte,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, payload, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
func (t *ICMPTracker) track(
|
||||||
|
srcIP netip.Addr,
|
||||||
|
dstIP netip.Addr,
|
||||||
|
id uint16,
|
||||||
|
typecode layers.ICMPv4TypeCode,
|
||||||
|
direction nftypes.Direction,
|
||||||
|
ruleId []byte,
|
||||||
|
payload []byte,
|
||||||
|
size int,
|
||||||
|
) {
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
if exists {
|
if exists {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
typ, code := typecode.Type(), typecode.Code()
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
icmpInfo := ICMPInfo{
|
||||||
|
TypeCode: typecode,
|
||||||
|
}
|
||||||
|
if len(payload) > 0 {
|
||||||
|
icmpInfo.PayloadLen = len(payload)
|
||||||
|
if icmpInfo.PayloadLen > MaxICMPPayloadLength {
|
||||||
|
icmpInfo.PayloadLen = MaxICMPPayloadLength
|
||||||
|
}
|
||||||
|
copy(icmpInfo.PayloadData[:], payload[:icmpInfo.PayloadLen])
|
||||||
|
}
|
||||||
|
|
||||||
// non echo requests don't need tracking
|
// non echo requests don't need tracking
|
||||||
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||||
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -138,7 +243,7 @@ func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typec
|
|||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
|
|
||||||
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
t.logger.Trace("New %s ICMP connection %s - %s", direction, key, icmpInfo)
|
||||||
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, []byte{}, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -28,7 +28,7 @@ func BenchmarkICMPTracker(b *testing.B) {
|
|||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, []byte{}, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
|
|||||||
@@ -39,8 +39,12 @@ const (
|
|||||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
|
||||||
|
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
|
||||||
|
|
||||||
|
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
|
||||||
|
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
|
||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,10 +53,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
|||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]PeerRule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
type RouteRules []RouteRule
|
type RouteRules []*RouteRule
|
||||||
|
|
||||||
func (r RouteRules) Sort() {
|
func (r RouteRules) Sort() {
|
||||||
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||||
// Deny rules come first
|
// Deny rules come first
|
||||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||||
return -1
|
return -1
|
||||||
@@ -71,7 +75,6 @@ type Manager struct {
|
|||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -99,6 +102,14 @@ type Manager struct {
|
|||||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
|
||||||
|
blockRule firewall.Rule
|
||||||
|
|
||||||
|
// Internal 1:1 DNAT
|
||||||
|
dnatEnabled atomic.Bool
|
||||||
|
dnatMappings map[netip.Addr]netip.Addr
|
||||||
|
dnatMutex sync.RWMutex
|
||||||
|
dnatBiMap *biDNATMap
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -146,6 +157,11 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
}
|
}
|
||||||
|
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
|
||||||
|
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
@@ -179,6 +195,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
|
dnatMappings: make(map[netip.Addr]netip.Addr),
|
||||||
}
|
}
|
||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
@@ -201,41 +218,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.blockInvalidRouted(iface); err != nil {
|
|
||||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, fmt.Errorf("set filter: %w", err)
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||||
if m.forwarder.Load() == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse wireguard network: %w", err)
|
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||||
}
|
}
|
||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
if _, err := m.AddRouteFiltering(
|
rule, err := m.addRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
wgPrefix,
|
firewall.Network{Prefix: wgPrefix},
|
||||||
firewall.ProtocolALL,
|
firewall.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionDrop,
|
firewall.ActionDrop,
|
||||||
); err != nil {
|
)
|
||||||
return fmt.Errorf("block wg nte : %w", err)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Block networks that we're a client of
|
// TODO: Block networks that we're a client of
|
||||||
|
|
||||||
return nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) determineRouting() error {
|
func (m *Manager) determineRouting() error {
|
||||||
@@ -273,7 +284,7 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
case !m.netstack && m.nativeFirewall != nil:
|
||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
@@ -330,6 +341,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return m.stateful
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
@@ -413,10 +428,23 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination firewall.Network,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
@@ -426,34 +454,39 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
// TODO: consolidate these IDs
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: ruleID,
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
destination: destination,
|
dstSet: destination.Set,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
srcPort: sPort,
|
srcPort: sPort,
|
||||||
dstPort: dPort,
|
dstPort: dPort,
|
||||||
action: action,
|
action: action,
|
||||||
|
}
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules = append(m.routeRules, rule)
|
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.deleteRouteRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
ruleID := rule.ID()
|
ruleID := rule.ID()
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == ruleID
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
@@ -493,30 +526,60 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
// AddDNATRule adds a DNAT rule
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
if m.nativeFirewall == nil {
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
return nil, errNatNotSupported
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.UpdateSet(set, prefixes)
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.AddDNATRule(rule)
|
|
||||||
}
|
|
||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
m.mutex.Lock()
|
||||||
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
defer m.mutex.Unlock()
|
||||||
if m.nativeFirewall == nil {
|
|
||||||
return errNatNotSupported
|
var matches []*RouteRule
|
||||||
|
for _, rule := range m.routeRules {
|
||||||
|
if rule.dstSet == set {
|
||||||
|
matches = append(matches, rule)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return fmt.Errorf("no route rule found for set: %s", set)
|
||||||
|
}
|
||||||
|
|
||||||
|
destinations := matches[0].destinations
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
destinations = append(destinations, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||||
|
cmp := a.Addr().Compare(b.Addr())
|
||||||
|
if cmp != 0 {
|
||||||
|
return cmp
|
||||||
|
}
|
||||||
|
return a.Bits() - b.Bits()
|
||||||
|
})
|
||||||
|
|
||||||
|
destinations = slices.Compact(destinations)
|
||||||
|
|
||||||
|
for _, rule := range matches {
|
||||||
|
rule.destinations = destinations
|
||||||
|
}
|
||||||
|
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// FilterOutBound filters outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
func (m *Manager) FilterOutbound(packetData []byte, size int) bool {
|
||||||
return m.processOutgoingHooks(packetData, size)
|
return m.filterOutbound(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// FilterInbound filters incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
func (m *Manager) FilterInbound(packetData []byte, size int) bool {
|
||||||
return m.dropFilter(packetData, size)
|
return m.filterInbound(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs updates the list of local IPs
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
@@ -524,7 +587,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
func (m *Manager) filterOutbound(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -546,9 +609,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.stateful {
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
m.trackOutbound(d, srcIP, dstIP, size)
|
m.translateOutboundDNAT(packetData, d)
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -600,7 +662,7 @@ func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, d.icmp4.Payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -613,7 +675,7 @@ func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byt
|
|||||||
flags := getTCPFlags(&d.tcp)
|
flags := getTCPFlags(&d.tcp)
|
||||||
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, d.icmp4.Payload, size)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -652,9 +714,9 @@ func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets.
|
// filterInbound implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
func (m *Manager) filterInbound(packetData []byte, size int) bool {
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -676,8 +738,15 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
if translated := m.translateInboundReverse(packetData, d); translated {
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
// Re-decode after translation to get original addresses
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
m.logger.Error("Failed to re-decode packet after reverse DNAT: %v", err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
srcIP, dstIP = m.extractIPs(d)
|
||||||
|
}
|
||||||
|
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -717,9 +786,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// if running in netstack mode we need to pass this to the forwarder
|
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||||
if m.netstack && m.localForwarding {
|
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||||
return m.handleNetstackLocalTraffic(packetData)
|
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||||
|
return m.handleForwardedLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track inbound packets to get the correct direction and session id for flows
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
@@ -729,8 +799,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||||
|
|
||||||
fwd := m.forwarder.Load()
|
fwd := m.forwarder.Load()
|
||||||
if fwd == nil {
|
if fwd == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
@@ -764,7 +833,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
|||||||
proto, pnum := getProtocolFromPacket(d)
|
proto, pnum := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
if !pass {
|
||||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
@@ -790,8 +860,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
|||||||
if fwd == nil {
|
if fwd == nil {
|
||||||
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
||||||
} else {
|
} else {
|
||||||
|
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
|
||||||
|
|
||||||
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject routed packet: %v", err)
|
m.logger.Error("Failed to inject routed packet: %v", err)
|
||||||
|
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -988,8 +1061,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
if !rule.destination.Contains(dstAddr) {
|
destMatched := false
|
||||||
|
for _, dst := range rule.destinations {
|
||||||
|
if dst.Contains(dstAddr) {
|
||||||
|
destMatched = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !destMatched {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -1017,11 +1097,6 @@ func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
|
||||||
m.wgNetwork = network
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
@@ -1091,7 +1166,22 @@ func (m *Manager) EnableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.determineRouting()
|
if err := m.determineRouting(); err != nil {
|
||||||
|
return fmt.Errorf("determine routing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder.Load() == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := m.blockInvalidRouted(m.wgIface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("block invalid routed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.blockRule = rule
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
@@ -1116,5 +1206,12 @@ func (m *Manager) DisableRouting() error {
|
|||||||
|
|
||||||
log.Debug("forwarder stopped")
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
if m.blockRule != nil {
|
||||||
|
if err := m.deleteRouteRule(m.blockRule); err != nil {
|
||||||
|
return fmt.Errorf("delete block rule: %w", err)
|
||||||
|
}
|
||||||
|
m.blockRule = nil
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply scenario-specific setup
|
// Apply scenario-specific setup
|
||||||
sc.setupFunc(manager)
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
@@ -193,13 +188,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -219,18 +214,13 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-populate connection table
|
// Pre-populate connection table
|
||||||
srcIPs := generateRandomIPs(count)
|
srcIPs := generateRandomIPs(count)
|
||||||
dstIPs := generateRandomIPs(count)
|
dstIPs := generateRandomIPs(count)
|
||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@@ -238,11 +228,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.processOutgoingHooks(testOut, 0)
|
manager.filterOutbound(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn, 0)
|
manager.filterInbound(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -267,23 +257,18 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP := generateRandomIPs(1)[0]
|
srcIP := generateRandomIPs(1)[0]
|
||||||
dstIP := generateRandomIPs(1)[0]
|
dstIP := generateRandomIPs(1)[0]
|
||||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "post_handshake",
|
state: "post_handshake",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -477,25 +426,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.processOutgoingHooks(outbound, 0)
|
manager.filterOutbound(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound, 0)
|
manager.filterInbound(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -624,17 +568,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@@ -655,9 +599,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
manager.filterOutbound(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx], 0)
|
manager.filterInbound(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -761,19 +700,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn, 0)
|
manager.filterOutbound(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck, 0)
|
manager.filterInbound(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack, 0)
|
manager.filterOutbound(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request, 0)
|
manager.filterOutbound(p.request, 0)
|
||||||
manager.dropFilter(p.response, 0)
|
manager.filterInbound(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient, 0)
|
manager.filterOutbound(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer, 0)
|
manager.filterInbound(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer, 0)
|
manager.filterInbound(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient, 0)
|
manager.filterOutbound(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
@@ -826,15 +760,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn, 0)
|
manager.filterOutbound(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack, 0)
|
manager.filterInbound(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack, 0)
|
manager.filterOutbound(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@@ -856,8 +790,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
manager.filterOutbound(outPackets[connIdx], 0)
|
||||||
manager.dropFilter(inPackets[connIdx], 0)
|
manager.filterInbound(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
@@ -950,17 +879,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn, 0)
|
manager.filterOutbound(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck, 0)
|
manager.filterInbound(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack, 0)
|
manager.filterOutbound(p.ack, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request, 0)
|
manager.filterOutbound(p.request, 0)
|
||||||
manager.dropFilter(p.response, 0)
|
manager.filterInbound(p.response, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient, 0)
|
manager.filterOutbound(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer, 0)
|
manager.filterInbound(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer, 0)
|
manager.filterInbound(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient, 0)
|
manager.filterOutbound(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
dst := fw.Network{Prefix: r.dest}
|
||||||
|
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -15,15 +15,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
localIP := net.ParseIP("100.10.0.100")
|
localIP := netip.MustParseAddr("100.10.0.100")
|
||||||
wgNet := &net.IPNet{
|
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
@@ -42,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = wgNet
|
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
err = manager.UpdateLocalIPs()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -188,16 +183,313 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
ruleAction: fw.ActionAccept,
|
ruleAction: fw.ActionAccept,
|
||||||
shouldBeBlocked: true,
|
shouldBeBlocked: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow UDP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP packet doesn't match UDP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP packet doesn't match TCP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match TCP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match UDP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic within port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block TCP traffic outside port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Edge Case - Port at Range Boundary",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8100,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP Port Range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 5060,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow multiple destination ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow multiple source ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
// New drop test cases
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop UDP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop ICMP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolICMP,
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop all traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolALL,
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop traffic from multiple source ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop multiple destination ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic within port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Accept TCP traffic outside drop port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic with source port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 32100,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed rule - drop specific port but allow other ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.DropIncoming(packet, 0)
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
|
// add general accept rule to test drop rule
|
||||||
|
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
|
||||||
|
rules, err := manager.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
fw.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, rules)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
net.ParseIP(tc.ruleIP),
|
||||||
@@ -217,7 +509,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.DropIncoming(packet, 0)
|
isDropped := manager.FilterInbound(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -283,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
dev := mocks.NewMockDevice(ctrl)
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
localIP, wgNet, err := net.ParseCIDR(network)
|
wgNet := netip.MustParsePrefix(network)
|
||||||
require.NoError(tb, err)
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: wgNet.Addr(),
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -303,8 +594,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled.Load())
|
require.True(tb, manager.routingEnabled.Load())
|
||||||
require.False(tb, manager.nativeRouter.Load())
|
require.False(tb, manager.nativeRouter.Load())
|
||||||
@@ -321,7 +612,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
type rule struct {
|
type rule struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -347,7 +638,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -363,7 +654,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -379,7 +670,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -395,7 +686,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolUDP,
|
proto: fw.ProtocolUDP,
|
||||||
dstPort: &fw.Port{Values: []uint16{53}},
|
dstPort: &fw.Port{Values: []uint16{53}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -409,7 +700,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
@@ -424,7 +715,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -440,7 +731,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -456,7 +747,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -472,7 +763,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -488,7 +779,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -507,7 +798,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
},
|
},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -521,7 +812,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
@@ -536,33 +827,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
shouldPass: true,
|
shouldPass: true,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Multiple source networks with mismatched protocol",
|
|
||||||
srcIP: "172.16.0.1",
|
|
||||||
dstIP: "192.168.1.100",
|
|
||||||
// Should not match TCP rule
|
|
||||||
proto: fw.ProtocolUDP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 80,
|
|
||||||
rule: rule{
|
|
||||||
sources: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
|
||||||
},
|
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
proto: fw.ProtocolTCP,
|
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
|
||||||
action: fw.ActionAccept,
|
|
||||||
},
|
|
||||||
shouldPass: false,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "Allow multiple destination ports",
|
name: "Allow multiple destination ports",
|
||||||
srcIP: "100.10.0.1",
|
srcIP: "100.10.0.1",
|
||||||
@@ -572,7 +843,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -588,7 +859,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -604,7 +875,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
@@ -621,7 +892,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -640,7 +911,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 7999,
|
dstPort: 7999,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -659,7 +930,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{
|
srcPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -678,7 +949,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{
|
srcPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -700,7 +971,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8100,
|
dstPort: 8100,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -719,7 +990,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 5060,
|
dstPort: 5060,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolUDP,
|
proto: fw.ProtocolUDP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -738,7 +1009,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -757,7 +1028,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -773,7 +1044,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
},
|
},
|
||||||
@@ -791,17 +1062,158 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
},
|
},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
},
|
},
|
||||||
shouldPass: false,
|
shouldPass: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "Drop empty destination set",
|
||||||
|
srcIP: "172.16.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
},
|
||||||
|
dest: fw.Network{Set: fw.Set{}},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Accept TCP traffic outside drop port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
action: fw.ActionDrop,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow UDP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP packet doesn't match UDP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP packet doesn't match TCP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match TCP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match UDP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if tc.rule.action == fw.ActionDrop {
|
||||||
|
// add general accept rule to test drop rule
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
@@ -821,7 +1233,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
srcIP := netip.MustParseAddr(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := netip.MustParseAddr(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not FilterInbound, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
||||||
@@ -836,7 +1248,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
rules []struct {
|
rules []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -857,7 +1269,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name: "Drop rules take precedence over accept",
|
name: "Drop rules take precedence over accept",
|
||||||
rules: []struct {
|
rules: []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -866,7 +1278,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Accept rule added first
|
// Accept rule added first
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -874,7 +1286,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Drop rule added second but should be evaluated first
|
// Drop rule added second but should be evaluated first
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -912,7 +1324,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name: "Multiple drop rules take precedence",
|
name: "Multiple drop rules take precedence",
|
||||||
rules: []struct {
|
rules: []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -921,14 +1333,14 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Accept all
|
// Accept all
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Drop specific port
|
// Drop specific port
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -936,7 +1348,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Drop different port
|
// Drop different port
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -1015,3 +1427,50 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteACLSet(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
// Add rule that uses the set (initially empty)
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
|
||||||
|
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||||
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now the packet should be allowed
|
||||||
|
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||||
|
}
|
||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
@@ -270,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -284,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
@@ -327,7 +321,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), 0) {
|
if m.filterInbound(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -395,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -457,7 +447,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
result := manager.filterOutbound(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@@ -467,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
result = manager.filterOutbound(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -508,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
@@ -568,7 +553,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
drop := manager.FilterOutbound(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
@@ -635,7 +620,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
drop = manager.filterInbound(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -684,7 +669,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
drop = manager.filterOutbound(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@@ -706,8 +691,208 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
drop = manager.filterInbound(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateSetMerge(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
initialPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
// Update the set with initial prefixes
|
||||||
|
err = manager.UpdateSet(set, initialPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test initial prefixes work
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
dstIP1 := netip.MustParseAddr("10.0.0.100")
|
||||||
|
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||||
|
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||||
|
|
||||||
|
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||||
|
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||||
|
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
|
||||||
|
|
||||||
|
newPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, newPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check that all original prefixes are still included
|
||||||
|
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||||
|
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||||
|
|
||||||
|
// Check that new prefixes are included
|
||||||
|
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||||
|
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||||
|
|
||||||
|
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||||
|
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||||
|
|
||||||
|
// Verify the rule has all prefixes
|
||||||
|
manager.mutex.RLock()
|
||||||
|
foundRule := false
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
foundRule = true
|
||||||
|
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
|
||||||
|
"Rule should have all prefixes merged")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
require.True(t, foundRule, "Rule should be found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateSetDeduplication(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
initialPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, initialPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check the internal state for deduplication
|
||||||
|
manager.mutex.RLock()
|
||||||
|
foundRule := false
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
foundRule = true
|
||||||
|
// Should have deduplicated to 2 prefixes
|
||||||
|
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
|
||||||
|
|
||||||
|
// Check the prefixes are correct
|
||||||
|
expectedPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
for i, prefix := range expectedPrefixes {
|
||||||
|
require.True(t, r.destinations[i] == prefix,
|
||||||
|
"Prefix should match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
require.True(t, foundRule, "Rule should be found")
|
||||||
|
|
||||||
|
// Test with overlapping prefixes of different sizes
|
||||||
|
overlappingPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/16"), // More general
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"), // More general
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, overlappingPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check that all prefixes are included (no deduplication of overlapping prefixes)
|
||||||
|
manager.mutex.RLock()
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
// Should have all 4 prefixes (2 original + 2 new more general ones)
|
||||||
|
require.Len(t, r.destinations, 4,
|
||||||
|
"Overlapping prefixes should not be deduplicated")
|
||||||
|
|
||||||
|
// Verify they're sorted correctly (more specific prefixes should come first)
|
||||||
|
prefixes := make([]string, 0, len(r.destinations))
|
||||||
|
for _, p := range r.destinations {
|
||||||
|
prefixes = append(prefixes, p.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check sorted order
|
||||||
|
require.Equal(t, []string{
|
||||||
|
"10.0.0.0/16",
|
||||||
|
"10.0.0.0/24",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"192.168.1.0/24",
|
||||||
|
}, prefixes, "Prefixes should be sorted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Test functionality with all prefixes
|
||||||
|
testCases := []struct {
|
||||||
|
dstIP netip.Addr
|
||||||
|
expected bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
|
||||||
|
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
|
||||||
|
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
|
||||||
|
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
|
||||||
|
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
for _, tc := range testCases {
|
||||||
|
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -86,5 +86,5 @@ type epID stack.TransportEndpointID
|
|||||||
|
|
||||||
func (i epID) String() string {
|
func (i epID) String() string {
|
||||||
// src and remote is swapped
|
// src and remote is swapped
|
||||||
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
return fmt.Sprintf("%s:%d → %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
@@ -17,6 +19,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
@@ -29,14 +32,16 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
// ruleIdMap is used to store the rule ID for a given connection
|
||||||
|
ruleIdMap sync.Map
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ip net.IP
|
ip tcpip.Address
|
||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ones, _ := iface.Address().Network.Mask.Size()
|
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
PrefixLen: ones,
|
PrefixLen: iface.Address().Network.Bits(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
ip: iface.Address().IP,
|
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
}
|
}
|
||||||
|
|
||||||
receiveWindow := defaultReceiveWindow
|
receiveWindow := defaultReceiveWindow
|
||||||
@@ -162,8 +166,39 @@ func (f *Forwarder) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
if f.netstack && f.ip.Equal(addr) {
|
||||||
return net.IPv4(127, 0, 0, 1)
|
return net.IPv4(127, 0, 0, 1)
|
||||||
}
|
}
|
||||||
return addr.AsSlice()
|
return addr.AsSlice()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||||
|
key := buildKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
f.ruleIdMap.LoadOrStore(key, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
|
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||||
|
return value.([]byte), true
|
||||||
|
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||||
|
return value.([]byte), true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
|
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
|
||||||
|
return conntrack.ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
}
|
}
|
||||||
|
|
||||||
flowID := uuid.New()
|
flowID := uuid.New()
|
||||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
// TODO: support non-root
|
// TODO: support non-root
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||||
|
|
||||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
payload := fullPacket.AsSlice()
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
// For Echo Requests, send and handle response
|
// For Echo Requests, send and handle response
|
||||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||||
f.handleEchoResponse(icmpHdr, conn, id)
|
rxBytes := pkt.Size()
|
||||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||||
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, f.endpoint.mtu)
|
response := make([]byte, f.endpoint.mtu)
|
||||||
n, _, err := conn.ReadFrom(response)
|
n, _, err := conn.ReadFrom(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !isTimeout(err) {
|
if !isTimeout(err) {
|
||||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
@@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
|||||||
fullPacket = append(fullPacket, response[:n]...)
|
fullPacket = append(fullPacket, response[:n]...)
|
||||||
|
|
||||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
|
||||||
|
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
return len(fullPacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendICMPEvent stores flow events for ICMP packets
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
|
||||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
var rxPackets, txPackets uint64
|
||||||
|
if rxBytes > 0 {
|
||||||
|
rxPackets = 1
|
||||||
|
}
|
||||||
|
if txBytes > 0 {
|
||||||
|
txPackets = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.ICMP,
|
Protocol: nftypes.ICMP,
|
||||||
// TODO: handle ipv6
|
// TODO: handle ipv6
|
||||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
SourceIP: srcIp,
|
||||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
DestIP: dstIp,
|
||||||
ICMPType: icmpType,
|
ICMPType: icmpType,
|
||||||
ICMPCode: icmpCode,
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
// TODO: get packets/bytes
|
RxBytes: rxBytes,
|
||||||
})
|
TxBytes: txBytes,
|
||||||
|
RxPackets: rxPackets,
|
||||||
|
TxPackets: txPackets,
|
||||||
|
}
|
||||||
|
|
||||||
|
if typ == nftypes.TypeStart {
|
||||||
|
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||||
|
fields.RuleID = ruleId
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
@@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
flowID := uuid.New()
|
flowID := uuid.New()
|
||||||
|
|
||||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||||
var success bool
|
var success bool
|
||||||
defer func() {
|
defer func() {
|
||||||
if !success {
|
if !success {
|
||||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
defer func() {
|
|
||||||
if err := inConn.Close(); err != nil {
|
|
||||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
|
||||||
}
|
|
||||||
if err := outConn.Close(); err != nil {
|
|
||||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
|
||||||
}
|
|
||||||
ep.Close()
|
|
||||||
|
|
||||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Create context for managing the proxy goroutines
|
|
||||||
ctx, cancel := context.WithCancel(f.ctx)
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(outConn, inConn)
|
<-ctx.Done()
|
||||||
errChan <- err
|
// Close connections and endpoint.
|
||||||
}()
|
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
go func() {
|
}
|
||||||
_, err := io.Copy(inConn, outConn)
|
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||||
errChan <- err
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
}()
|
}
|
||||||
|
|
||||||
select {
|
ep.Close()
|
||||||
case <-ctx.Done():
|
}()
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
|
||||||
return
|
var wg sync.WaitGroup
|
||||||
case err := <-errChan:
|
wg.Add(2)
|
||||||
if err != nil && !isClosedError(err) {
|
|
||||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
var (
|
||||||
|
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||||
|
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||||
|
errInToOut error
|
||||||
|
errOutToIn error
|
||||||
|
)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||||
|
cancel()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
|
||||||
|
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||||
|
cancel()
|
||||||
|
wg.Done()
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if errInToOut != nil {
|
||||||
|
if !isClosedError(errInToOut) {
|
||||||
|
f.logger.Error("proxyTCP: copy error (in → out) for %s: %v", epID(id), errInToOut)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
if errOutToIn != nil {
|
||||||
|
if !isClosedError(errOutToIn) {
|
||||||
|
f.logger.Error("proxyTCP: copy error (out → in) for %s: %v", epID(id), errOutToIn)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var rxPackets, txPackets uint64
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
rxPackets = tcpStats.SegmentsSent.Value()
|
||||||
|
txPackets = tcpStats.SegmentsReceived.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.TCP,
|
Protocol: nftypes.TCP,
|
||||||
// TODO: handle ipv6
|
// TODO: handle ipv6
|
||||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
SourceIP: srcIp,
|
||||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
DestPort: id.LocalPort,
|
DestPort: id.LocalPort,
|
||||||
|
RxBytes: rxBytes,
|
||||||
|
TxBytes: txBytes,
|
||||||
|
RxPackets: rxPackets,
|
||||||
|
TxPackets: txPackets,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ep != nil {
|
if typ == nftypes.TypeStart {
|
||||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||||
// fields are flipped since this is the in conn
|
fields.RuleID = ruleId
|
||||||
// TODO: get bytes
|
|
||||||
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
|
||||||
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.flowLogger.StoreEvent(fields)
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
|||||||
@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
|
|
||||||
flowID := uuid.New()
|
flowID := uuid.New()
|
||||||
|
|
||||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||||
var success bool
|
var success bool
|
||||||
defer func() {
|
defer func() {
|
||||||
if !success {
|
if !success {
|
||||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[id] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
@@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
defer func() {
|
|
||||||
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := pConn.conn.Close(); err != nil {
|
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := pConn.outConn.Close(); err != nil {
|
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ep.Close()
|
ep.Close()
|
||||||
|
|
||||||
f.udpForwarder.Lock()
|
|
||||||
delete(f.udpForwarder.conns, id)
|
|
||||||
f.udpForwarder.Unlock()
|
|
||||||
|
|
||||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
var txBytes, rxBytes int64
|
||||||
|
var outboundErr, inboundErr error
|
||||||
|
|
||||||
|
// outbound->inbound: copy from pConn.conn to pConn.outConn
|
||||||
go func() {
|
go func() {
|
||||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
defer wg.Done()
|
||||||
|
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
|
// inbound->outbound: copy from pConn.outConn to pConn.conn
|
||||||
go func() {
|
go func() {
|
||||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
defer wg.Done()
|
||||||
|
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
wg.Wait()
|
||||||
case <-ctx.Done():
|
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||||
return
|
f.logger.Error("proxyUDP: copy error (outbound→inbound) for %s: %v", epID(id), outboundErr)
|
||||||
case err := <-errChan:
|
|
||||||
if err != nil && !isClosedError(err) {
|
|
||||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
|
||||||
}
|
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||||
|
f.logger.Error("proxyUDP: copy error (inbound→outbound) for %s: %v", epID(id), inboundErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rxPackets, txPackets uint64
|
||||||
|
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
rxPackets = udpStats.PacketsSent.Value()
|
||||||
|
txPackets = udpStats.PacketsReceived.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
delete(f.udpForwarder.conns, id)
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendUDPEvent stores flow events for UDP connections
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.UDP,
|
Protocol: nftypes.UDP,
|
||||||
// TODO: handle ipv6
|
// TODO: handle ipv6
|
||||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
SourceIP: srcIp,
|
||||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
DestPort: id.LocalPort,
|
DestPort: id.LocalPort,
|
||||||
|
RxBytes: rxBytes,
|
||||||
|
TxBytes: txBytes,
|
||||||
|
RxPackets: rxPackets,
|
||||||
|
TxPackets: txPackets,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ep != nil {
|
if typ == nftypes.TypeStart {
|
||||||
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||||
// fields are flipped since this is the in conn
|
fields.RuleID = ruleId
|
||||||
// TODO: get bytes
|
|
||||||
fields.RxPackets = tcpStats.PacketsSent.Value()
|
|
||||||
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.flowLogger.StoreEvent(fields)
|
f.flowLogger.StoreEvent(fields)
|
||||||
@@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
|
|||||||
return time.Since(lastSeen)
|
return time.Since(lastSeen)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
// copy reads from src and writes to dst.
|
||||||
|
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
|
||||||
bufp := bufPool.Get().(*[]byte)
|
bufp := bufPool.Get().(*[]byte)
|
||||||
defer bufPool.Put(bufp)
|
defer bufPool.Put(bufp)
|
||||||
buffer := *bufp
|
buffer := *bufp
|
||||||
|
var totalBytes int64 = 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return ctx.Err()
|
return totalBytes, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
return fmt.Errorf("set read deadline: %w", err)
|
return totalBytes, fmt.Errorf("set read deadline: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := src.Read(buffer)
|
n, err := src.Read(buffer)
|
||||||
@@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
|||||||
if isTimeout(err) {
|
if isTimeout(err) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return fmt.Errorf("read from %s: %w", direction, err)
|
return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dst.Write(buffer[:n])
|
nWritten, err := dst.Write(buffer[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("write to %s: %w", direction, err)
|
return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalBytes += int64(nWritten)
|
||||||
c.updateLastSeen()
|
c.updateLastSeen()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -45,24 +45,26 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if !ip.Is4() {
|
||||||
high := uint16(ipv4[0])
|
return
|
||||||
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
}
|
||||||
|
ipv4 := ip.AsSlice()
|
||||||
|
|
||||||
if bitmap[high] == nil {
|
high := uint16(ipv4[0])
|
||||||
bitmap[high] = &ipv4LowBitmap{}
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
}
|
|
||||||
|
|
||||||
index := low / 32
|
if bitmap[high] == nil {
|
||||||
bit := low % 32
|
bitmap[high] = &ipv4LowBitmap{}
|
||||||
bitmap[high].bitmap[index] |= 1 << bit
|
}
|
||||||
|
|
||||||
ipStr := ipv4.String()
|
index := low / 32
|
||||||
if _, exists := ipv4Set[ipStr]; !exists {
|
bit := low % 32
|
||||||
ipv4Set[ipStr] = struct{}{}
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
|
||||||
}
|
if _, exists := ipv4Set[ip]; !exists {
|
||||||
|
ipv4Set[ip] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,12 +81,12 @@ func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
|||||||
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||||
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -102,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
log.Debugf("process IP failed: %v", err)
|
log.Debugf("process IP failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -116,8 +124,8 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [256]*ipv4LowBitmap
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
ipv4Set := make(map[string]struct{})
|
ipv4Set := make(map[netip.Addr]struct{})
|
||||||
var ipv4Addresses []string
|
var ipv4Addresses []netip.Addr
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// 127.0.0.0/8
|
||||||
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
|
|||||||
@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range",
|
name: "Localhost range",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost standard address",
|
name: "Localhost standard address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range edge",
|
name: "Localhost range edge",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP matches",
|
name: "Local IP matches",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -68,11 +56,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match",
|
name: "Local IP doesn't match",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -80,11 +65,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match - addresses 32 apart",
|
name: "Local IP doesn't match - addresses 32 apart",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.33"),
|
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -92,11 +74,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("fe80::1"),
|
IP: netip.MustParseAddr("fe80::1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("fe80::"),
|
|
||||||
Mask: net.CIDRMask(64, 128),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
|
|||||||
408
client/firewall/uspfilter/nat.go
Normal file
408
client/firewall/uspfilter/nat.go
Normal file
@@ -0,0 +1,408 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrIPv4Only = errors.New("only IPv4 is supported for DNAT")
|
||||||
|
|
||||||
|
func ipv4Checksum(header []byte) uint16 {
|
||||||
|
if len(header) < 20 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
var sum1, sum2 uint32
|
||||||
|
|
||||||
|
// Parallel processing - unroll and compute two sums simultaneously
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[0:2]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[2:4]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[4:6]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[6:8]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[8:10]))
|
||||||
|
// Skip checksum field at [10:12]
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[12:14]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[14:16]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(header[16:18]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(header[18:20]))
|
||||||
|
|
||||||
|
sum := sum1 + sum2
|
||||||
|
|
||||||
|
// Handle remaining bytes for headers > 20 bytes
|
||||||
|
for i := 20; i < len(header)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(header[i : i+2]))
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(header)%2 == 1 {
|
||||||
|
sum += uint32(header[len(header)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
// Optimized carry fold - single iteration handles most cases
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func icmpChecksum(data []byte) uint16 {
|
||||||
|
var sum1, sum2, sum3, sum4 uint32
|
||||||
|
i := 0
|
||||||
|
|
||||||
|
// Process 16 bytes at once with 4 parallel accumulators
|
||||||
|
for i <= len(data)-16 {
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(data[i+2 : i+4]))
|
||||||
|
sum3 += uint32(binary.BigEndian.Uint16(data[i+4 : i+6]))
|
||||||
|
sum4 += uint32(binary.BigEndian.Uint16(data[i+6 : i+8]))
|
||||||
|
sum1 += uint32(binary.BigEndian.Uint16(data[i+8 : i+10]))
|
||||||
|
sum2 += uint32(binary.BigEndian.Uint16(data[i+10 : i+12]))
|
||||||
|
sum3 += uint32(binary.BigEndian.Uint16(data[i+12 : i+14]))
|
||||||
|
sum4 += uint32(binary.BigEndian.Uint16(data[i+14 : i+16]))
|
||||||
|
i += 16
|
||||||
|
}
|
||||||
|
|
||||||
|
sum := sum1 + sum2 + sum3 + sum4
|
||||||
|
|
||||||
|
// Handle remaining bytes
|
||||||
|
for i < len(data)-1 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(data[i : i+2]))
|
||||||
|
i += 2
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(data)%2 == 1 {
|
||||||
|
sum += uint32(data[len(data)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
type biDNATMap struct {
|
||||||
|
forward map[netip.Addr]netip.Addr
|
||||||
|
reverse map[netip.Addr]netip.Addr
|
||||||
|
}
|
||||||
|
|
||||||
|
func newBiDNATMap() *biDNATMap {
|
||||||
|
return &biDNATMap{
|
||||||
|
forward: make(map[netip.Addr]netip.Addr),
|
||||||
|
reverse: make(map[netip.Addr]netip.Addr),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) set(original, translated netip.Addr) {
|
||||||
|
b.forward[original] = translated
|
||||||
|
b.reverse[translated] = original
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) delete(original netip.Addr) {
|
||||||
|
if translated, exists := b.forward[original]; exists {
|
||||||
|
delete(b.forward, original)
|
||||||
|
delete(b.reverse, translated)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) getTranslated(original netip.Addr) (netip.Addr, bool) {
|
||||||
|
translated, exists := b.forward[original]
|
||||||
|
return translated, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (b *biDNATMap) getOriginal(translated netip.Addr) (netip.Addr, bool) {
|
||||||
|
original, exists := b.reverse[translated]
|
||||||
|
return original, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) AddInternalDNATMapping(originalAddr, translatedAddr netip.Addr) error {
|
||||||
|
if !originalAddr.IsValid() || !translatedAddr.IsValid() {
|
||||||
|
return fmt.Errorf("invalid IP addresses")
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.localipmanager.IsLocalIP(translatedAddr) {
|
||||||
|
return fmt.Errorf("cannot map to local IP: %s", translatedAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
defer m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
// Initialize both maps together if either is nil
|
||||||
|
if m.dnatMappings == nil || m.dnatBiMap == nil {
|
||||||
|
m.dnatMappings = make(map[netip.Addr]netip.Addr)
|
||||||
|
m.dnatBiMap = newBiDNATMap()
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMappings[originalAddr] = translatedAddr
|
||||||
|
m.dnatBiMap.set(originalAddr, translatedAddr)
|
||||||
|
|
||||||
|
if len(m.dnatMappings) == 1 {
|
||||||
|
m.dnatEnabled.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// RemoveInternalDNATMapping removes a 1:1 IP address mapping
|
||||||
|
func (m *Manager) RemoveInternalDNATMapping(originalAddr netip.Addr) error {
|
||||||
|
m.dnatMutex.Lock()
|
||||||
|
defer m.dnatMutex.Unlock()
|
||||||
|
|
||||||
|
if _, exists := m.dnatMappings[originalAddr]; !exists {
|
||||||
|
return fmt.Errorf("mapping not found for: %s", originalAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
delete(m.dnatMappings, originalAddr)
|
||||||
|
m.dnatBiMap.delete(originalAddr)
|
||||||
|
if len(m.dnatMappings) == 0 {
|
||||||
|
m.dnatEnabled.Store(false)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getDNATTranslation returns the translated address if a mapping exists
|
||||||
|
func (m *Manager) getDNATTranslation(addr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return addr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
translated, exists := m.dnatBiMap.getTranslated(addr)
|
||||||
|
m.dnatMutex.RUnlock()
|
||||||
|
return translated, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// findReverseDNATMapping finds original address for return traffic
|
||||||
|
func (m *Manager) findReverseDNATMapping(translatedAddr netip.Addr) (netip.Addr, bool) {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return translatedAddr, false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.dnatMutex.RLock()
|
||||||
|
original, exists := m.dnatBiMap.getOriginal(translatedAddr)
|
||||||
|
m.dnatMutex.RUnlock()
|
||||||
|
return original, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateOutboundDNAT applies DNAT translation to outbound packets
|
||||||
|
func (m *Manager) translateOutboundDNAT(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
dstIP := netip.AddrFrom4([4]byte{packetData[16], packetData[17], packetData[18], packetData[19]})
|
||||||
|
|
||||||
|
translatedIP, exists := m.getDNATTranslation(dstIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketDestination(packetData, d, translatedIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet destination: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("DNAT: %s -> %s", dstIP, translatedIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// translateInboundReverse applies reverse DNAT to inbound return traffic
|
||||||
|
func (m *Manager) translateInboundReverse(packetData []byte, d *decoder) bool {
|
||||||
|
if !m.dnatEnabled.Load() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.AddrFrom4([4]byte{packetData[12], packetData[13], packetData[14], packetData[15]})
|
||||||
|
|
||||||
|
originalIP, exists := m.findReverseDNATMapping(srcIP)
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rewritePacketSource(packetData, d, originalIP); err != nil {
|
||||||
|
m.logger.Error("Failed to rewrite packet source: %v", err)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("Reverse DNAT: %s -> %s", srcIP, originalIP)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketDestination replaces destination IP in the packet
|
||||||
|
func (m *Manager) rewritePacketDestination(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||||
|
return ErrIPv4Only
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldDst [4]byte
|
||||||
|
copy(oldDst[:], packetData[16:20])
|
||||||
|
newDst := newIP.As4()
|
||||||
|
|
||||||
|
copy(packetData[16:20], newDst[:])
|
||||||
|
|
||||||
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||||
|
return fmt.Errorf("invalid IP header length")
|
||||||
|
}
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
|
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||||
|
|
||||||
|
if len(d.decoded) > 1 {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.updateTCPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.updateUDPChecksum(packetData, ipHeaderLen, oldDst[:], newDst[:])
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// rewritePacketSource replaces the source IP address in the packet
|
||||||
|
func (m *Manager) rewritePacketSource(packetData []byte, d *decoder, newIP netip.Addr) error {
|
||||||
|
if len(packetData) < 20 || d.decoded[0] != layers.LayerTypeIPv4 || !newIP.Is4() {
|
||||||
|
return ErrIPv4Only
|
||||||
|
}
|
||||||
|
|
||||||
|
var oldSrc [4]byte
|
||||||
|
copy(oldSrc[:], packetData[12:16])
|
||||||
|
newSrc := newIP.As4()
|
||||||
|
|
||||||
|
copy(packetData[12:16], newSrc[:])
|
||||||
|
|
||||||
|
ipHeaderLen := int(d.ip4.IHL) * 4
|
||||||
|
if ipHeaderLen < 20 || ipHeaderLen > len(packetData) {
|
||||||
|
return fmt.Errorf("invalid IP header length")
|
||||||
|
}
|
||||||
|
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], 0)
|
||||||
|
ipChecksum := ipv4Checksum(packetData[:ipHeaderLen])
|
||||||
|
binary.BigEndian.PutUint16(packetData[10:12], ipChecksum)
|
||||||
|
|
||||||
|
if len(d.decoded) > 1 {
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
m.updateTCPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.updateUDPChecksum(packetData, ipHeaderLen, oldSrc[:], newSrc[:])
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.updateICMPChecksum(packetData, ipHeaderLen)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateTCPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
tcpStart := ipHeaderLen
|
||||||
|
if len(packetData) < tcpStart+18 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := tcpStart + 16
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateUDPChecksum(packetData []byte, ipHeaderLen int, oldIP, newIP []byte) {
|
||||||
|
udpStart := ipHeaderLen
|
||||||
|
if len(packetData) < udpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
checksumOffset := udpStart + 6
|
||||||
|
oldChecksum := binary.BigEndian.Uint16(packetData[checksumOffset : checksumOffset+2])
|
||||||
|
|
||||||
|
if oldChecksum == 0 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
newChecksum := incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
binary.BigEndian.PutUint16(packetData[checksumOffset:checksumOffset+2], newChecksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) updateICMPChecksum(packetData []byte, ipHeaderLen int) {
|
||||||
|
icmpStart := ipHeaderLen
|
||||||
|
if len(packetData) < icmpStart+8 {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpData := packetData[icmpStart:]
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], 0)
|
||||||
|
checksum := icmpChecksum(icmpData)
|
||||||
|
binary.BigEndian.PutUint16(icmpData[2:4], checksum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// incrementalUpdate performs incremental checksum update per RFC 1624
|
||||||
|
func incrementalUpdate(oldChecksum uint16, oldBytes, newBytes []byte) uint16 {
|
||||||
|
sum := uint32(^oldChecksum)
|
||||||
|
|
||||||
|
// Fast path for IPv4 addresses (4 bytes) - most common case
|
||||||
|
if len(oldBytes) == 4 && len(newBytes) == 4 {
|
||||||
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[0:2]))
|
||||||
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[2:4]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(newBytes[0:2]))
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(newBytes[2:4]))
|
||||||
|
} else {
|
||||||
|
// Fallback for other lengths
|
||||||
|
for i := 0; i < len(oldBytes)-1; i += 2 {
|
||||||
|
sum += uint32(^binary.BigEndian.Uint16(oldBytes[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(oldBytes)%2 == 1 {
|
||||||
|
sum += uint32(^oldBytes[len(oldBytes)-1]) << 8
|
||||||
|
}
|
||||||
|
|
||||||
|
for i := 0; i < len(newBytes)-1; i += 2 {
|
||||||
|
sum += uint32(binary.BigEndian.Uint16(newBytes[i : i+2]))
|
||||||
|
}
|
||||||
|
if len(newBytes)%2 == 1 {
|
||||||
|
sum += uint32(newBytes[len(newBytes)-1]) << 8
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
sum = (sum & 0xFFFF) + (sum >> 16)
|
||||||
|
if sum > 0xFFFF {
|
||||||
|
sum++
|
||||||
|
}
|
||||||
|
|
||||||
|
return ^uint16(sum)
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule (delegates to native firewall for port forwarding)
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return nil, errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule (delegates to native firewall)
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if m.nativeFirewall == nil {
|
||||||
|
return errNatNotSupported
|
||||||
|
}
|
||||||
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
416
client/firewall/uspfilter/nat_bench_test.go
Normal file
@@ -0,0 +1,416 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BenchmarkDNATTranslation measures the performance of DNAT operations
|
||||||
|
func BenchmarkDNATTranslation(b *testing.B) {
|
||||||
|
scenarios := []struct {
|
||||||
|
name string
|
||||||
|
proto layers.IPProtocol
|
||||||
|
setupDNAT bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "tcp_with_dnat",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
setupDNAT: true,
|
||||||
|
description: "TCP packet with DNAT translation enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "tcp_without_dnat",
|
||||||
|
proto: layers.IPProtocolTCP,
|
||||||
|
setupDNAT: false,
|
||||||
|
description: "TCP packet without DNAT (baseline)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "udp_with_dnat",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
setupDNAT: true,
|
||||||
|
description: "UDP packet with DNAT translation enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "udp_without_dnat",
|
||||||
|
proto: layers.IPProtocolUDP,
|
||||||
|
setupDNAT: false,
|
||||||
|
description: "UDP packet without DNAT (baseline)",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "icmp_with_dnat",
|
||||||
|
proto: layers.IPProtocolICMPv4,
|
||||||
|
setupDNAT: true,
|
||||||
|
description: "ICMP packet with DNAT translation enabled",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "icmp_without_dnat",
|
||||||
|
proto: layers.IPProtocolICMPv4,
|
||||||
|
setupDNAT: false,
|
||||||
|
description: "ICMP packet without DNAT (baseline)",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, sc := range scenarios {
|
||||||
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Setup DNAT mapping if needed
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
|
||||||
|
if sc.setupDNAT {
|
||||||
|
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create test packets
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
outboundPacket := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||||
|
|
||||||
|
// Pre-establish connection for reverse DNAT test
|
||||||
|
if sc.setupDNAT {
|
||||||
|
manager.filterOutbound(outboundPacket, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
// Benchmark outbound DNAT translation
|
||||||
|
b.Run("outbound", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Create fresh packet each time since translation modifies it
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, originalIP, sc.proto, 12345, 80)
|
||||||
|
manager.filterOutbound(packet, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Benchmark inbound reverse DNAT translation
|
||||||
|
if sc.setupDNAT {
|
||||||
|
b.Run("inbound_reverse", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Create fresh packet each time since translation modifies it
|
||||||
|
packet := generateDNATTestPacket(b, translatedIP, srcIP, sc.proto, 80, 12345)
|
||||||
|
manager.filterInbound(packet, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDNATConcurrency tests DNAT performance under concurrent load
|
||||||
|
func BenchmarkDNATConcurrency(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Setup multiple DNAT mappings
|
||||||
|
numMappings := 100
|
||||||
|
originalIPs := make([]netip.Addr, numMappings)
|
||||||
|
translatedIPs := make([]netip.Addr, numMappings)
|
||||||
|
|
||||||
|
for i := 0; i < numMappings; i++ {
|
||||||
|
originalIPs[i] = netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
translatedIPs[i] = netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
err := manager.AddInternalDNATMapping(originalIPs[i], translatedIPs[i])
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
|
||||||
|
// Pre-generate packets
|
||||||
|
outboundPackets := make([][]byte, numMappings)
|
||||||
|
inboundPackets := make([][]byte, numMappings)
|
||||||
|
for i := 0; i < numMappings; i++ {
|
||||||
|
outboundPackets[i] = generateDNATTestPacket(b, srcIP, originalIPs[i], layers.IPProtocolTCP, 12345, 80)
|
||||||
|
inboundPackets[i] = generateDNATTestPacket(b, translatedIPs[i], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||||
|
// Establish connections
|
||||||
|
manager.filterOutbound(outboundPackets[i], 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
|
||||||
|
b.Run("concurrent_outbound", func(b *testing.B) {
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
idx := i % numMappings
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, originalIPs[idx], layers.IPProtocolTCP, 12345, 80)
|
||||||
|
manager.filterOutbound(packet, 0)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("concurrent_inbound", func(b *testing.B) {
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
idx := i % numMappings
|
||||||
|
packet := generateDNATTestPacket(b, translatedIPs[idx], srcIP, layers.IPProtocolTCP, 80, 12345)
|
||||||
|
manager.filterInbound(packet, 0)
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDNATScaling tests how DNAT performance scales with number of mappings
|
||||||
|
func BenchmarkDNATScaling(b *testing.B) {
|
||||||
|
mappingCounts := []int{1, 10, 100, 1000}
|
||||||
|
|
||||||
|
for _, count := range mappingCounts {
|
||||||
|
b.Run(fmt.Sprintf("mappings_%d", count), func(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Setup DNAT mappings
|
||||||
|
for i := 0; i < count; i++ {
|
||||||
|
originalIP := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
translatedIP := netip.MustParseAddr(fmt.Sprintf("10.0.%d.%d", (i/254)+1, (i%254)+1))
|
||||||
|
err := manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(b, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test with the last mapping added (worst case for lookup)
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
lastOriginal := netip.MustParseAddr(fmt.Sprintf("192.168.%d.%d", ((count-1)/254)+1, ((count-1)%254)+1))
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, lastOriginal, layers.IPProtocolTCP, 12345, 80)
|
||||||
|
manager.filterOutbound(packet, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// generateDNATTestPacket creates a test packet for DNAT benchmarking
|
||||||
|
func generateDNATTestPacket(tb testing.TB, srcIP, dstIP netip.Addr, proto layers.IPProtocol, srcPort, dstPort uint16) []byte {
|
||||||
|
tb.Helper()
|
||||||
|
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP.AsSlice(),
|
||||||
|
DstIP: dstIP.AsSlice(),
|
||||||
|
Protocol: proto,
|
||||||
|
}
|
||||||
|
|
||||||
|
var transportLayer gopacket.SerializableLayer
|
||||||
|
switch proto {
|
||||||
|
case layers.IPProtocolTCP:
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(srcPort),
|
||||||
|
DstPort: layers.TCPPort(dstPort),
|
||||||
|
SYN: true,
|
||||||
|
}
|
||||||
|
require.NoError(tb, tcp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = tcp
|
||||||
|
case layers.IPProtocolUDP:
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
require.NoError(tb, udp.SetNetworkLayerForChecksum(ipv4))
|
||||||
|
transportLayer = udp
|
||||||
|
case layers.IPProtocolICMPv4:
|
||||||
|
icmp := &layers.ICMPv4{
|
||||||
|
TypeCode: layers.CreateICMPv4TypeCode(layers.ICMPv4TypeEchoRequest, 0),
|
||||||
|
}
|
||||||
|
transportLayer = icmp
|
||||||
|
}
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{ComputeChecksums: true, FixLengths: true}
|
||||||
|
err := gopacket.SerializeLayers(buf, opts, ipv4, transportLayer, gopacket.Payload("test"))
|
||||||
|
require.NoError(tb, err)
|
||||||
|
return buf.Bytes()
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkChecksumUpdate specifically benchmarks checksum calculation performance
|
||||||
|
func BenchmarkChecksumUpdate(b *testing.B) {
|
||||||
|
// Create test data for checksum calculations
|
||||||
|
testData := make([]byte, 64) // Typical packet size for checksum testing
|
||||||
|
for i := range testData {
|
||||||
|
testData[i] = byte(i)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("ipv4_checksum", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = ipv4Checksum(testData[:20]) // IPv4 header is typically 20 bytes
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("icmp_checksum", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = icmpChecksum(testData)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("incremental_update", func(b *testing.B) {
|
||||||
|
oldBytes := []byte{192, 168, 1, 100}
|
||||||
|
newBytes := []byte{10, 0, 0, 100}
|
||||||
|
oldChecksum := uint16(0x1234)
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = incrementalUpdate(oldChecksum, oldBytes, newBytes)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDNATMemoryAllocations checks for memory allocations in DNAT operations
|
||||||
|
func BenchmarkDNATMemoryAllocations(b *testing.B) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(b, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(b, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set logger to error level to reduce noise during benchmarking
|
||||||
|
manager.SetLogLevel(log.ErrorLevel)
|
||||||
|
defer func() {
|
||||||
|
// Restore to info level after benchmark
|
||||||
|
manager.SetLogLevel(log.InfoLevel)
|
||||||
|
}()
|
||||||
|
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(b, err)
|
||||||
|
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, originalIP, layers.IPProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.ReportAllocs()
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Create fresh packet each time to isolate allocation testing
|
||||||
|
testPacket := make([]byte, len(packet))
|
||||||
|
copy(testPacket, packet)
|
||||||
|
|
||||||
|
// Parse the packet fresh each time to get a clean decoder
|
||||||
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
err = d.parser.DecodeLayers(testPacket, &d.decoded)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
manager.translateOutboundDNAT(testPacket, d)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkDirectIPExtraction tests the performance improvement of direct IP extraction
|
||||||
|
func BenchmarkDirectIPExtraction(b *testing.B) {
|
||||||
|
// Create a test packet
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
packet := generateDNATTestPacket(b, srcIP, dstIP, layers.IPProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
b.Run("direct_byte_access", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Direct extraction from packet bytes
|
||||||
|
_ = netip.AddrFrom4([4]byte{packet[16], packet[17], packet[18], packet[19]})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("decoder_extraction", func(b *testing.B) {
|
||||||
|
// Create decoder once for comparison
|
||||||
|
d := &decoder{decoded: []gopacket.LayerType{}}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
err := d.parser.DecodeLayers(packet, &d.decoded)
|
||||||
|
assert.NoError(b, err)
|
||||||
|
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// Extract using decoder (traditional method)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
|
_ = dst
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkChecksumOptimizations compares optimized vs standard checksum implementations
|
||||||
|
func BenchmarkChecksumOptimizations(b *testing.B) {
|
||||||
|
// Create test IPv4 header (20 bytes)
|
||||||
|
header := make([]byte, 20)
|
||||||
|
for i := range header {
|
||||||
|
header[i] = byte(i)
|
||||||
|
}
|
||||||
|
// Clear checksum field
|
||||||
|
header[10] = 0
|
||||||
|
header[11] = 0
|
||||||
|
|
||||||
|
b.Run("optimized_ipv4_checksum", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = ipv4Checksum(header)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test incremental checksum updates
|
||||||
|
oldIP := []byte{192, 168, 1, 100}
|
||||||
|
newIP := []byte{10, 0, 0, 100}
|
||||||
|
oldChecksum := uint16(0x1234)
|
||||||
|
|
||||||
|
b.Run("optimized_incremental_update", func(b *testing.B) {
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
_ = incrementalUpdate(oldChecksum, oldIP, newIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
145
client/firewall/uspfilter/nat_test.go
Normal file
145
client/firewall/uspfilter/nat_test.go
Normal file
@@ -0,0 +1,145 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
)
|
||||||
|
|
||||||
|
// TestDNATTranslationCorrectness verifies DNAT translation works correctly
|
||||||
|
func TestDNATTranslationCorrectness(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
srcIP := netip.MustParseAddr("172.16.0.1")
|
||||||
|
|
||||||
|
// Add DNAT mapping
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
protocol layers.IPProtocol
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{"TCP", layers.IPProtocolTCP, 12345, 80},
|
||||||
|
{"UDP", layers.IPProtocolUDP, 12345, 53},
|
||||||
|
{"ICMP", layers.IPProtocolICMPv4, 0, 0},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
// Test outbound DNAT translation
|
||||||
|
outboundPacket := generateDNATTestPacket(t, srcIP, originalIP, tc.protocol, tc.srcPort, tc.dstPort)
|
||||||
|
originalOutbound := make([]byte, len(outboundPacket))
|
||||||
|
copy(originalOutbound, outboundPacket)
|
||||||
|
|
||||||
|
// Process outbound packet (should translate destination)
|
||||||
|
translated := manager.translateOutboundDNAT(outboundPacket, parsePacket(t, outboundPacket))
|
||||||
|
require.True(t, translated, "Outbound packet should be translated")
|
||||||
|
|
||||||
|
// Verify destination IP was changed
|
||||||
|
dstIPAfter := netip.AddrFrom4([4]byte{outboundPacket[16], outboundPacket[17], outboundPacket[18], outboundPacket[19]})
|
||||||
|
require.Equal(t, translatedIP, dstIPAfter, "Destination IP should be translated")
|
||||||
|
|
||||||
|
// Test inbound reverse DNAT translation
|
||||||
|
inboundPacket := generateDNATTestPacket(t, translatedIP, srcIP, tc.protocol, tc.dstPort, tc.srcPort)
|
||||||
|
originalInbound := make([]byte, len(inboundPacket))
|
||||||
|
copy(originalInbound, inboundPacket)
|
||||||
|
|
||||||
|
// Process inbound packet (should reverse translate source)
|
||||||
|
reversed := manager.translateInboundReverse(inboundPacket, parsePacket(t, inboundPacket))
|
||||||
|
require.True(t, reversed, "Inbound packet should be reverse translated")
|
||||||
|
|
||||||
|
// Verify source IP was changed back to original
|
||||||
|
srcIPAfter := netip.AddrFrom4([4]byte{inboundPacket[12], inboundPacket[13], inboundPacket[14], inboundPacket[15]})
|
||||||
|
require.Equal(t, originalIP, srcIPAfter, "Source IP should be reverse translated")
|
||||||
|
|
||||||
|
// Test that checksums are recalculated correctly
|
||||||
|
if tc.protocol != layers.IPProtocolICMPv4 {
|
||||||
|
// For TCP/UDP, verify the transport checksum was updated
|
||||||
|
require.NotEqual(t, originalOutbound, outboundPacket, "Outbound packet should be modified")
|
||||||
|
require.NotEqual(t, originalInbound, inboundPacket, "Inbound packet should be modified")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// parsePacket helper to create a decoder for testing
|
||||||
|
func parsePacket(t testing.TB, packetData []byte) *decoder {
|
||||||
|
t.Helper()
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
|
||||||
|
err := d.parser.DecodeLayers(packetData, &d.decoded)
|
||||||
|
require.NoError(t, err)
|
||||||
|
return d
|
||||||
|
}
|
||||||
|
|
||||||
|
// TestDNATMappingManagement tests adding/removing DNAT mappings
|
||||||
|
func TestDNATMappingManagement(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
originalIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
translatedIP := netip.MustParseAddr("10.0.0.100")
|
||||||
|
|
||||||
|
// Test adding mapping
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, translatedIP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify mapping exists
|
||||||
|
result, exists := manager.getDNATTranslation(originalIP)
|
||||||
|
require.True(t, exists)
|
||||||
|
require.Equal(t, translatedIP, result)
|
||||||
|
|
||||||
|
// Test reverse lookup
|
||||||
|
reverseResult, exists := manager.findReverseDNATMapping(translatedIP)
|
||||||
|
require.True(t, exists)
|
||||||
|
require.Equal(t, originalIP, reverseResult)
|
||||||
|
|
||||||
|
// Test removing mapping
|
||||||
|
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify mapping no longer exists
|
||||||
|
_, exists = manager.getDNATTranslation(originalIP)
|
||||||
|
require.False(t, exists)
|
||||||
|
|
||||||
|
_, exists = manager.findReverseDNATMapping(translatedIP)
|
||||||
|
require.False(t, exists)
|
||||||
|
|
||||||
|
// Test error cases
|
||||||
|
err = manager.AddInternalDNATMapping(netip.Addr{}, translatedIP)
|
||||||
|
require.Error(t, err, "Should reject invalid original IP")
|
||||||
|
|
||||||
|
err = manager.AddInternalDNATMapping(originalIP, netip.Addr{})
|
||||||
|
require.Error(t, err, "Should reject invalid translated IP")
|
||||||
|
|
||||||
|
err = manager.RemoveInternalDNATMapping(originalIP)
|
||||||
|
require.Error(t, err, "Should error when removing non-existent mapping")
|
||||||
|
}
|
||||||
@@ -29,14 +29,15 @@ func (r *PeerRule) ID() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RouteRule struct {
|
type RouteRule struct {
|
||||||
id string
|
id string
|
||||||
mgmtId []byte
|
mgmtId []byte
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
destination netip.Prefix
|
dstSet firewall.Set
|
||||||
proto firewall.Protocol
|
destinations []netip.Prefix
|
||||||
srcPort *firewall.Port
|
proto firewall.Protocol
|
||||||
dstPort *firewall.Port
|
srcPort *firewall.Port
|
||||||
action firewall.Action
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
|
|||||||
@@ -401,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.processOutgoingHooks(packetData, 0)
|
dropped := m.filterOutbound(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -198,12 +195,12 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.forwarder.Store(&forwarder.Forwarder{})
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -222,12 +219,12 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.nativeRouter.Store(false)
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -245,7 +242,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.nativeRouter.Store(true)
|
m.nativeRouter.Store(true)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -263,7 +260,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -425,8 +422,8 @@ func TestTracePacket(t *testing.T) {
|
|||||||
|
|
||||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||||
"100.10.0.100 should be recognized as a local IP")
|
"100.10.0.100 should be recognized as a local IP")
|
||||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
|
||||||
"172.17.0.2 should not be recognized as a local IP")
|
"192.168.17.2 should not be recognized as a local IP")
|
||||||
|
|
||||||
pb := tc.packetBuilder()
|
pb := tc.packetBuilder()
|
||||||
|
|
||||||
|
|||||||
94
client/iface/bind/activity.go
Normal file
94
client/iface/bind/activity.go
Normal file
@@ -0,0 +1,94 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/monotime"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
saveFrequency = int64(5 * time.Second)
|
||||||
|
)
|
||||||
|
|
||||||
|
type PeerRecord struct {
|
||||||
|
Address netip.AddrPort
|
||||||
|
LastActivity atomic.Int64 // UnixNano timestamp
|
||||||
|
}
|
||||||
|
|
||||||
|
type ActivityRecorder struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
peers map[string]*PeerRecord // publicKey to PeerRecord map
|
||||||
|
addrToPeer map[netip.AddrPort]*PeerRecord // address to PeerRecord map
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewActivityRecorder() *ActivityRecorder {
|
||||||
|
return &ActivityRecorder{
|
||||||
|
peers: make(map[string]*PeerRecord),
|
||||||
|
addrToPeer: make(map[netip.AddrPort]*PeerRecord),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLastActivities returns a snapshot of peer last activity
|
||||||
|
func (r *ActivityRecorder) GetLastActivities() map[string]time.Time {
|
||||||
|
r.mu.RLock()
|
||||||
|
defer r.mu.RUnlock()
|
||||||
|
|
||||||
|
activities := make(map[string]time.Time, len(r.peers))
|
||||||
|
for key, record := range r.peers {
|
||||||
|
unixNano := record.LastActivity.Load()
|
||||||
|
activities[key] = time.Unix(0, unixNano)
|
||||||
|
}
|
||||||
|
return activities
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpsertAddress adds or updates the address for a publicKey
|
||||||
|
func (r *ActivityRecorder) UpsertAddress(publicKey string, address netip.AddrPort) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
|
||||||
|
if pr, exists := r.peers[publicKey]; exists {
|
||||||
|
delete(r.addrToPeer, pr.Address)
|
||||||
|
pr.Address = address
|
||||||
|
} else {
|
||||||
|
record := &PeerRecord{
|
||||||
|
Address: address,
|
||||||
|
}
|
||||||
|
record.LastActivity.Store(monotime.Now())
|
||||||
|
r.peers[publicKey] = record
|
||||||
|
}
|
||||||
|
|
||||||
|
r.addrToPeer[address] = r.peers[publicKey]
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *ActivityRecorder) Remove(publicKey string) {
|
||||||
|
r.mu.Lock()
|
||||||
|
defer r.mu.Unlock()
|
||||||
|
if record, exists := r.peers[publicKey]; exists {
|
||||||
|
delete(r.addrToPeer, record.Address)
|
||||||
|
delete(r.peers, publicKey)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// record updates LastActivity for the given address using atomic store
|
||||||
|
func (r *ActivityRecorder) record(address netip.AddrPort) {
|
||||||
|
r.mu.RLock()
|
||||||
|
record, ok := r.addrToPeer[address]
|
||||||
|
r.mu.RUnlock()
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("could not find record for address %s", address)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
now := monotime.Now()
|
||||||
|
last := record.LastActivity.Load()
|
||||||
|
if now-last < saveFrequency {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
_ = record.LastActivity.CompareAndSwap(last, now)
|
||||||
|
}
|
||||||
27
client/iface/bind/activity_test.go
Normal file
27
client/iface/bind/activity_test.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestActivityRecorder_GetLastActivities(t *testing.T) {
|
||||||
|
peer := "peer1"
|
||||||
|
ar := NewActivityRecorder()
|
||||||
|
ar.UpsertAddress("peer1", netip.MustParseAddrPort("192.168.0.5:51820"))
|
||||||
|
activities := ar.GetLastActivities()
|
||||||
|
|
||||||
|
p, ok := activities[peer]
|
||||||
|
if !ok {
|
||||||
|
t.Fatalf("Expected activity for peer %s, but got none", peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.IsZero() {
|
||||||
|
t.Fatalf("Expected activity for peer %s, but got zero", peer)
|
||||||
|
}
|
||||||
|
|
||||||
|
if p.Before(time.Now().Add(-2 * time.Minute)) {
|
||||||
|
t.Fatalf("Expected activity for peer %s to be recent, but got %v", peer, p)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -1,6 +1,7 @@
|
|||||||
package bind
|
package bind
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
@@ -51,22 +52,24 @@ type ICEBind struct {
|
|||||||
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
closedChanMu sync.RWMutex // protect the closeChan recreation from reading from it.
|
||||||
closed bool
|
closed bool
|
||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *UniversalUDPMuxDefault
|
udpMux *UniversalUDPMuxDefault
|
||||||
address wgaddr.Address
|
address wgaddr.Address
|
||||||
|
activityRecorder *ActivityRecorder
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
RecvChan: make(chan RecvMessage, 1),
|
RecvChan: make(chan RecvMessage, 1),
|
||||||
transportNet: transportNet,
|
transportNet: transportNet,
|
||||||
filterFn: filterFn,
|
filterFn: filterFn,
|
||||||
endpoints: make(map[netip.Addr]net.Conn),
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
address: address,
|
address: address,
|
||||||
|
activityRecorder: NewActivityRecorder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := receiverCreator{
|
rc := receiverCreator{
|
||||||
@@ -100,6 +103,10 @@ func (s *ICEBind) Close() error {
|
|||||||
return s.StdNetBind.Close()
|
return s.StdNetBind.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *ICEBind) ActivityRecorder() *ActivityRecorder {
|
||||||
|
return s.activityRecorder
|
||||||
|
}
|
||||||
|
|
||||||
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
// GetICEMux returns the ICE UDPMux that was created and used by ICEBind
|
||||||
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
||||||
s.muUDPMux.Lock()
|
s.muUDPMux.Lock()
|
||||||
@@ -199,6 +206,11 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
addrPort := msg.Addr.(*net.UDPAddr).AddrPort()
|
||||||
|
|
||||||
|
if isTransportPkg(msg.Buffers, msg.N) {
|
||||||
|
s.activityRecorder.record(addrPort)
|
||||||
|
}
|
||||||
|
|
||||||
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
ep := &wgConn.StdNetEndpoint{AddrPort: addrPort} // TODO: remove allocation
|
||||||
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
wgConn.GetSrcFromControl(msg.OOB[:msg.NN], ep)
|
||||||
eps[i] = ep
|
eps[i] = ep
|
||||||
@@ -257,6 +269,13 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
|||||||
copy(buffs[0], msg.Buffer)
|
copy(buffs[0], msg.Buffer)
|
||||||
sizes[0] = len(msg.Buffer)
|
sizes[0] = len(msg.Buffer)
|
||||||
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
eps[0] = wgConn.Endpoint(msg.Endpoint)
|
||||||
|
|
||||||
|
if isTransportPkg(buffs, sizes[0]) {
|
||||||
|
if ep, ok := eps[0].(*Endpoint); ok {
|
||||||
|
c.activityRecorder.record(ep.AddrPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
return 1, nil
|
return 1, nil
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -272,3 +291,19 @@ func putMessages(msgs *[]ipv6.Message, msgsPool *sync.Pool) {
|
|||||||
}
|
}
|
||||||
msgsPool.Put(msgs)
|
msgsPool.Put(msgs)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func isTransportPkg(buffers [][]byte, n int) bool {
|
||||||
|
// The first buffer should contain at least 4 bytes for type
|
||||||
|
if len(buffers[0]) < 4 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// WireGuard packet type is a little-endian uint32 at start
|
||||||
|
packetType := binary.LittleEndian.Uint32(buffers[0][:4])
|
||||||
|
|
||||||
|
// Check if packetType matches known WireGuard message types
|
||||||
|
if packetType == 4 && n > 32 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|||||||
@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a.AsSlice()) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|||||||
17
client/iface/configurer/common.go
Normal file
17
client/iface/configurer/common.go
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
||||||
|
ipNets := make([]net.IPNet, len(prefixes))
|
||||||
|
for i, prefix := range prefixes {
|
||||||
|
ipNets[i] = net.IPNet{
|
||||||
|
IP: prefix.Addr().AsSlice(), // Convert netip.Addr to net.IP
|
||||||
|
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return ipNets
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ package configurer
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -12,6 +13,8 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var zeroKey wgtypes.Key
|
||||||
|
|
||||||
type KernelConfigurer struct {
|
type KernelConfigurer struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
}
|
}
|
||||||
@@ -43,7 +46,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -52,7 +55,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, ke
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: allowedIps,
|
AllowedIPs: prefixesToIPNets(allowedIps),
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
@@ -89,10 +92,10 @@ func (c *KernelConfigurer) RemovePeer(peerKey string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
ipNet := net.IPNet{
|
||||||
if err != nil {
|
IP: allowedIP.Addr().AsSlice(),
|
||||||
return err
|
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -103,7 +106,7 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
UpdateOnly: true,
|
UpdateOnly: true,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
AllowedIPs: []net.IPNet{ipNet},
|
||||||
}
|
}
|
||||||
|
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
@@ -116,10 +119,10 @@ func (c *KernelConfigurer) AddAllowedIP(peerKey string, allowedIP string) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
func (c *KernelConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
ipNet := net.IPNet{
|
||||||
if err != nil {
|
IP: allowedIP.Addr().AsSlice(),
|
||||||
return fmt.Errorf("parse allowed IP: %w", err)
|
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -187,7 +190,11 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
defer wg.Close()
|
defer func() {
|
||||||
|
if err := wg.Close(); err != nil {
|
||||||
|
log.Errorf("Failed to close wgctrl client: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
// validate if device with name exists
|
// validate if device with name exists
|
||||||
_, err = wg.Device(c.deviceName)
|
_, err = wg.Device(c.deviceName)
|
||||||
@@ -201,14 +208,75 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
func (c *KernelConfigurer) Close() {
|
func (c *KernelConfigurer) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
wg, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
return nil, fmt.Errorf("wgctl: %w", err)
|
||||||
}
|
}
|
||||||
return WGStats{
|
defer func() {
|
||||||
LastHandshake: peer.LastHandshakeTime,
|
err = wg.Close()
|
||||||
TxBytes: peer.TransmitBytes,
|
if err != nil {
|
||||||
RxBytes: peer.ReceiveBytes,
|
log.Errorf("Got error while closing wgctl: %v", err)
|
||||||
}, nil
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||||
|
}
|
||||||
|
fullStats := &Stats{
|
||||||
|
DeviceName: wgDevice.Name,
|
||||||
|
PublicKey: wgDevice.PublicKey.String(),
|
||||||
|
ListenPort: wgDevice.ListenPort,
|
||||||
|
FWMark: wgDevice.FirewallMark,
|
||||||
|
Peers: []Peer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range wgDevice.Peers {
|
||||||
|
peer := Peer{
|
||||||
|
PublicKey: p.PublicKey.String(),
|
||||||
|
AllowedIPs: p.AllowedIPs,
|
||||||
|
TxBytes: p.TransmitBytes,
|
||||||
|
RxBytes: p.ReceiveBytes,
|
||||||
|
LastHandshake: p.LastHandshakeTime,
|
||||||
|
PresharedKey: p.PresharedKey != zeroKey,
|
||||||
|
}
|
||||||
|
if p.Endpoint != nil {
|
||||||
|
peer.Endpoint = *p.Endpoint
|
||||||
|
}
|
||||||
|
fullStats.Peers = append(fullStats.Peers, peer)
|
||||||
|
}
|
||||||
|
return fullStats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
|
stats := make(map[string]WGStats)
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("wgctl: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Got error while closing wgctl: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range wgDevice.Peers {
|
||||||
|
stats[peer.PublicKey.String()] = WGStats{
|
||||||
|
LastHandshake: peer.LastHandshakeTime,
|
||||||
|
TxBytes: peer.TransmitBytes,
|
||||||
|
RxBytes: peer.ReceiveBytes,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KernelConfigurer) LastActivities() map[string]time.Time {
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -14,22 +16,39 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/device"
|
"golang.zx2c4.com/wireguard/device"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
privateKey = "private_key"
|
||||||
|
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||||
|
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||||
|
ipcKeyTxBytes = "tx_bytes"
|
||||||
|
ipcKeyRxBytes = "rx_bytes"
|
||||||
|
allowedIP = "allowed_ip"
|
||||||
|
endpoint = "endpoint"
|
||||||
|
fwmark = "fwmark"
|
||||||
|
listenPort = "listen_port"
|
||||||
|
publicKey = "public_key"
|
||||||
|
presharedKey = "preshared_key"
|
||||||
|
)
|
||||||
|
|
||||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||||
|
|
||||||
type WGUSPConfigurer struct {
|
type WGUSPConfigurer struct {
|
||||||
device *device.Device
|
device *device.Device
|
||||||
deviceName string
|
deviceName string
|
||||||
|
activityRecorder *bind.ActivityRecorder
|
||||||
|
|
||||||
uapiListener net.Listener
|
uapiListener net.Listener
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewUSPConfigurer(device *device.Device, deviceName string) *WGUSPConfigurer {
|
func NewUSPConfigurer(device *device.Device, deviceName string, activityRecorder *bind.ActivityRecorder) *WGUSPConfigurer {
|
||||||
wgCfg := &WGUSPConfigurer{
|
wgCfg := &WGUSPConfigurer{
|
||||||
device: device,
|
device: device,
|
||||||
deviceName: deviceName,
|
deviceName: deviceName,
|
||||||
|
activityRecorder: activityRecorder,
|
||||||
}
|
}
|
||||||
wgCfg.startUAPI()
|
wgCfg.startUAPI()
|
||||||
return wgCfg
|
return wgCfg
|
||||||
@@ -52,7 +71,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -61,7 +80,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: allowedIps,
|
AllowedIPs: prefixesToIPNets(allowedIps),
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
@@ -71,7 +90,19 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, kee
|
|||||||
Peers: []wgtypes.PeerConfig{peer},
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
}
|
}
|
||||||
|
|
||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
if ipcErr := c.device.IpcSet(toWgUserspaceString(config)); ipcErr != nil {
|
||||||
|
return ipcErr
|
||||||
|
}
|
||||||
|
|
||||||
|
if endpoint != nil {
|
||||||
|
addr, err := netip.ParseAddr(endpoint.IP.String())
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to parse endpoint address: %w", err)
|
||||||
|
}
|
||||||
|
addrPort := netip.AddrPortFrom(addr, uint16(endpoint.Port))
|
||||||
|
c.activityRecorder.UpsertAddress(peerKey, addrPort)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
||||||
@@ -88,13 +119,16 @@ func (c *WGUSPConfigurer) RemovePeer(peerKey string) error {
|
|||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
Peers: []wgtypes.PeerConfig{peer},
|
Peers: []wgtypes.PeerConfig{peer},
|
||||||
}
|
}
|
||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
ipcErr := c.device.IpcSet(toWgUserspaceString(config))
|
||||||
|
|
||||||
|
c.activityRecorder.Remove(peerKey)
|
||||||
|
return ipcErr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
ipNet := net.IPNet{
|
||||||
if err != nil {
|
IP: allowedIP.Addr().AsSlice(),
|
||||||
return err
|
Mask: net.CIDRMask(allowedIP.Bits(), allowedIP.Addr().BitLen()),
|
||||||
}
|
}
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
@@ -105,7 +139,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
UpdateOnly: true,
|
UpdateOnly: true,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
AllowedIPs: []net.IPNet{ipNet},
|
||||||
}
|
}
|
||||||
|
|
||||||
config := wgtypes.Config{
|
config := wgtypes.Config{
|
||||||
@@ -115,7 +149,7 @@ func (c *WGUSPConfigurer) AddAllowedIP(peerKey string, allowedIP string) error {
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
ipc, err := c.device.IpcGet()
|
ipc, err := c.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -138,6 +172,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
|||||||
|
|
||||||
foundPeer := false
|
foundPeer := false
|
||||||
removedAllowedIP := false
|
removedAllowedIP := false
|
||||||
|
ip := allowedIP.String()
|
||||||
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
@@ -160,8 +196,8 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
|||||||
|
|
||||||
// Append the line to the output string
|
// Append the line to the output string
|
||||||
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
if foundPeer && strings.HasPrefix(line, "allowed_ip=") {
|
||||||
allowedIP := strings.TrimPrefix(line, "allowed_ip=")
|
allowedIPStr := strings.TrimPrefix(line, "allowed_ip=")
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIP)
|
_, ipNet, err := net.ParseCIDR(allowedIPStr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -178,6 +214,19 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||||
|
ipcStr, err := c.device.IpcGet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseStatus(c.deviceName, ipcStr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *WGUSPConfigurer) LastActivities() map[string]time.Time {
|
||||||
|
return c.activityRecorder.GetLastActivities()
|
||||||
|
}
|
||||||
|
|
||||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||||
func (t *WGUSPConfigurer) startUAPI() {
|
func (t *WGUSPConfigurer) startUAPI() {
|
||||||
var err error
|
var err error
|
||||||
@@ -217,91 +266,75 @@ func (t *WGUSPConfigurer) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
ipc, err := t.device.IpcGet()
|
ipc, err := t.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
return nil, fmt.Errorf("ipc get: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := findPeerInfo(ipc, peerKey, []string{
|
return parseTransfers(ipc)
|
||||||
"last_handshake_time_sec",
|
|
||||||
"last_handshake_time_nsec",
|
|
||||||
"tx_bytes",
|
|
||||||
"rx_bytes",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
|
||||||
}
|
|
||||||
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
|
||||||
}
|
|
||||||
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return WGStats{
|
|
||||||
LastHandshake: time.Unix(sec, nsec),
|
|
||||||
TxBytes: txBytes,
|
|
||||||
RxBytes: rxBytes,
|
|
||||||
}, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
stats := make(map[string]WGStats)
|
||||||
if err != nil {
|
var (
|
||||||
return nil, fmt.Errorf("parse key: %w", err)
|
currentKey string
|
||||||
}
|
currentStats WGStats
|
||||||
|
hasPeer bool
|
||||||
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
)
|
||||||
|
lines := strings.Split(ipc, "\n")
|
||||||
lines := strings.Split(ipcInput, "\n")
|
|
||||||
|
|
||||||
configFound := map[string]string{}
|
|
||||||
foundPeer := false
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
// If we're within the details of the found peer and encounter another public key,
|
// If we're within the details of the found peer and encounter another public key,
|
||||||
// this means we're starting another peer's details. So, stop.
|
// this means we're starting another peer's details. So, stop.
|
||||||
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
if strings.HasPrefix(line, "public_key=") {
|
||||||
break
|
peerID := strings.TrimPrefix(line, "public_key=")
|
||||||
}
|
h, err := hex.DecodeString(peerID)
|
||||||
|
if err != nil {
|
||||||
// Identify the peer with the specific public key
|
return nil, fmt.Errorf("decode peerID: %w", err)
|
||||||
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
|
||||||
foundPeer = true
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, key := range searchConfigKeys {
|
|
||||||
if foundPeer && strings.HasPrefix(line, key+"=") {
|
|
||||||
v := strings.SplitN(line, "=", 2)
|
|
||||||
configFound[v[0]] = v[1]
|
|
||||||
}
|
}
|
||||||
|
currentKey = base64.StdEncoding.EncodeToString(h)
|
||||||
|
currentStats = WGStats{} // Reset stats for the new peer
|
||||||
|
hasPeer = true
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if !hasPeer {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
key := strings.SplitN(line, "=", 2)
|
||||||
|
if len(key) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
switch key[0] {
|
||||||
|
case ipcKeyLastHandshakeTimeSec:
|
||||||
|
hs, err := toLastHandshake(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
currentStats.LastHandshake = hs
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyRxBytes:
|
||||||
|
rxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.RxBytes = rxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyTxBytes:
|
||||||
|
TxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.TxBytes = TxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: use multierr
|
return stats, nil
|
||||||
for _, key := range searchConfigKeys {
|
|
||||||
if _, ok := configFound[key]; !ok {
|
|
||||||
return configFound, fmt.Errorf("config key not found: %s", key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !foundPeer {
|
|
||||||
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return configFound, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||||
@@ -355,9 +388,154 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toLastHandshake(stringVar string) (time.Time, error) {
|
||||||
|
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||||
|
}
|
||||||
|
return time.Unix(sec, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toBytes(s string) (int64, error) {
|
||||||
|
return strconv.ParseInt(s, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.ControlPlaneMark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
||||||
|
// Decode hex string to bytes
|
||||||
|
keyBytes, err := hex.DecodeString(hexKey)
|
||||||
|
if err != nil {
|
||||||
|
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
||||||
|
if len(keyBytes) != 32 {
|
||||||
|
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to wgtypes.Key
|
||||||
|
var key wgtypes.Key
|
||||||
|
copy(key[:], keyBytes)
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||||
|
stats := &Stats{DeviceName: deviceName}
|
||||||
|
var currentPeer *Peer
|
||||||
|
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(line, "=", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := parts[0]
|
||||||
|
val := parts[1]
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
case privateKey:
|
||||||
|
key, err := hexToWireguardKey(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse private key: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats.PublicKey = key.PublicKey().String()
|
||||||
|
case publicKey:
|
||||||
|
// Save previous peer
|
||||||
|
if currentPeer != nil {
|
||||||
|
stats.Peers = append(stats.Peers, *currentPeer)
|
||||||
|
}
|
||||||
|
key, err := hexToWireguardKey(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse public key: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer = &Peer{
|
||||||
|
PublicKey: key.String(),
|
||||||
|
}
|
||||||
|
case listenPort:
|
||||||
|
if port, err := strconv.Atoi(val); err == nil {
|
||||||
|
stats.ListenPort = port
|
||||||
|
}
|
||||||
|
case fwmark:
|
||||||
|
if fwmark, err := strconv.Atoi(val); err == nil {
|
||||||
|
stats.FWMark = fwmark
|
||||||
|
}
|
||||||
|
case endpoint:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse endpoint: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse endpoint port: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.Endpoint = net.UDPAddr{
|
||||||
|
IP: net.ParseIP(host),
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
case allowedIP:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(val)
|
||||||
|
if err == nil {
|
||||||
|
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
||||||
|
}
|
||||||
|
case ipcKeyTxBytes:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rxBytes, err := toBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.TxBytes = rxBytes
|
||||||
|
case ipcKeyRxBytes:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rxBytes, err := toBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.RxBytes = rxBytes
|
||||||
|
|
||||||
|
case ipcKeyLastHandshakeTimeSec:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := toLastHandshake(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.LastHandshake = ts
|
||||||
|
case presharedKey:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if val != "" {
|
||||||
|
currentPeer.PresharedKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentPeer != nil {
|
||||||
|
stats.Peers = append(stats.Peers, *currentPeer)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,10 +2,8 @@ package configurer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@@ -34,58 +32,35 @@ errno=0
|
|||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
func Test_findPeerInfo(t *testing.T) {
|
func Test_parseTransfers(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
peerKey string
|
peerKey string
|
||||||
searchKeys []string
|
want WGStats
|
||||||
want map[string]string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single",
|
name: "single",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
||||||
searchKeys: []string{"tx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 0,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 0,
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple",
|
name: "multiple",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 38333,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 2224,
|
||||||
"rx_bytes": "2224",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "lastpeer",
|
name: "lastpeer",
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 1212111,
|
||||||
"tx_bytes": "1212111",
|
RxBytes: 1929999999,
|
||||||
"rx_bytes": "1929999999",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "peer not found",
|
|
||||||
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
|
||||||
searchKeys: nil,
|
|
||||||
want: nil,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "key not found",
|
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
|
||||||
searchKeys: []string{"tx_bytes", "unknown_key"},
|
|
||||||
want: map[string]string{
|
|
||||||
"tx_bytes": "1212111",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
|
|||||||
key, err := wgtypes.NewKey(res)
|
key, err := wgtypes.NewKey(res)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
stats, err := parseTransfers(ipcFixture)
|
||||||
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
if err != nil {
|
||||||
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
require.NoError(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, ok := stats[key.String()]
|
||||||
|
if !ok {
|
||||||
|
require.True(t, ok)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, tt.want, stat)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
24
client/iface/configurer/wgshow.go
Normal file
24
client/iface/configurer/wgshow.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
PublicKey string
|
||||||
|
Endpoint net.UDPAddr
|
||||||
|
AllowedIPs []net.IPNet
|
||||||
|
TxBytes int64
|
||||||
|
RxBytes int64
|
||||||
|
LastHandshake time.Time
|
||||||
|
PresharedKey bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
DeviceName string
|
||||||
|
PublicKey string
|
||||||
|
ListenPort int
|
||||||
|
FWMark int
|
||||||
|
Peers []Peer
|
||||||
|
}
|
||||||
@@ -24,6 +24,7 @@ type WGTunDevice struct {
|
|||||||
mtu int
|
mtu int
|
||||||
iceBind *bind.ICEBind
|
iceBind *bind.ICEBind
|
||||||
tunAdapter TunAdapter
|
tunAdapter TunAdapter
|
||||||
|
disableDNS bool
|
||||||
|
|
||||||
name string
|
name string
|
||||||
device *device.Device
|
device *device.Device
|
||||||
@@ -32,7 +33,7 @@ type WGTunDevice struct {
|
|||||||
configurer WGConfigurer
|
configurer WGConfigurer
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter) *WGTunDevice {
|
func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind *bind.ICEBind, tunAdapter TunAdapter, disableDNS bool) *WGTunDevice {
|
||||||
return &WGTunDevice{
|
return &WGTunDevice{
|
||||||
address: address,
|
address: address,
|
||||||
port: port,
|
port: port,
|
||||||
@@ -40,6 +41,7 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
|||||||
mtu: mtu,
|
mtu: mtu,
|
||||||
iceBind: iceBind,
|
iceBind: iceBind,
|
||||||
tunAdapter: tunAdapter,
|
tunAdapter: tunAdapter,
|
||||||
|
disableDNS: disableDNS,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -49,6 +51,13 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
routesString := routesToString(routes)
|
routesString := routesToString(routes)
|
||||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
searchDomainsToString := searchDomainsToString(searchDomains)
|
||||||
|
|
||||||
|
// Skip DNS configuration when DisableDNS is enabled
|
||||||
|
if t.disableDNS {
|
||||||
|
log.Info("DNS is disabled, skipping DNS and search domain configuration")
|
||||||
|
dns = ""
|
||||||
|
searchDomainsToString = ""
|
||||||
|
}
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to create Android interface: %s", err)
|
log.Errorf("failed to create Android interface: %s", err)
|
||||||
@@ -70,7 +79,7 @@ func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string
|
|||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -61,7 +61,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -10,11 +9,11 @@ import (
|
|||||||
|
|
||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// DropOutgoing filter outgoing packets from host to external destinations
|
// FilterOutbound filter outgoing packets from host to external destinations
|
||||||
DropOutgoing(packetData []byte, size int) bool
|
FilterOutbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// DropIncoming filter incoming packets from external sources to host
|
// FilterInbound filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte, size int) bool
|
FilterInbound(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
@@ -24,9 +23,6 @@ type PacketFilter interface {
|
|||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
SetNetwork(*net.IPNet)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
@@ -58,7 +54,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
if filter.FilterOutbound(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@@ -82,7 +78,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
if !filter.FilterInbound(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().FilterInbound(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
filter.EXPECT().FilterOutbound(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
|||||||
@@ -71,7 +71,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
// this helps with support for the older NetBird clients that had a hardcoded direct mode
|
||||||
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
// t.device.DisableSomeRoamingForBrokenMobileSemantics()
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
log.Info("create nbnetstack tun interface")
|
log.Info("create nbnetstack tun interface")
|
||||||
|
|
||||||
// TODO: get from service listener runtime IP
|
// TODO: get from service listener runtime IP
|
||||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("last ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("netstack using address: %s", t.address.IP)
|
log.Debugf("netstack using address: %s", t.address.IP)
|
||||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||||
@@ -68,7 +72,7 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
device.NewLogger(wgLogLevel(), "[netbird] "),
|
device.NewLogger(wgLogLevel(), "[netbird] "),
|
||||||
)
|
)
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
_ = tunIface.Close()
|
_ = tunIface.Close()
|
||||||
|
|||||||
@@ -64,7 +64,7 @@ func (t *USPDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -94,7 +94,7 @@ func (t *TunDevice) Create() (WGConfigurer, error) {
|
|||||||
return nil, fmt.Errorf("error assigning ip: %s", err)
|
return nil, fmt.Errorf("error assigning ip: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
t.configurer = configurer.NewUSPConfigurer(t.device, t.name)
|
t.configurer = configurer.NewUSPConfigurer(t.device, t.name, t.iceBind.ActivityRecorder())
|
||||||
err = t.configurer.ConfigureInterface(t.key, t.port)
|
err = t.configurer.ConfigureInterface(t.key, t.port)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.device.Close()
|
t.device.Close()
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
@@ -11,10 +12,12 @@ import (
|
|||||||
|
|
||||||
type WGConfigurer interface {
|
type WGConfigurer interface {
|
||||||
ConfigureInterface(privateKey string, port int) error
|
ConfigureInterface(privateKey string, port int) error
|
||||||
UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error
|
||||||
RemovePeer(peerKey string) error
|
RemovePeer(peerKey string) error
|
||||||
AddAllowedIP(peerKey string, allowedIP string) error
|
AddAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error
|
||||||
Close()
|
Close()
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
|
FullStats() (*configurer.Stats, error)
|
||||||
|
LastActivities() map[string]time.Time
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ip := address.IP.String()
|
ip := address.IP.String()
|
||||||
mask := "0x" + address.Network.Mask.String()
|
|
||||||
|
// Convert prefix length to hex netmask
|
||||||
|
prefixLen := address.Network.Bits()
|
||||||
|
if !address.IP.Is4() {
|
||||||
|
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||||
|
}
|
||||||
|
|
||||||
|
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||||
|
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||||
|
|
||||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||||
|
|
||||||
|
|||||||
@@ -43,6 +43,7 @@ type WGIFaceOpts struct {
|
|||||||
MobileArgs *device.MobileIFaceArguments
|
MobileArgs *device.MobileIFaceArguments
|
||||||
TransportNet transport.Net
|
TransportNet transport.Net
|
||||||
FilterFn bind.FilterFn
|
FilterFn bind.FilterFn
|
||||||
|
DisableDNS bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// WGIface represents an interface instance
|
// WGIface represents an interface instance
|
||||||
@@ -111,14 +112,14 @@ func (w *WGIface) UpdateAddr(newAddr string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
// UpdatePeer updates existing Wireguard Peer or creates a new one if doesn't exist
|
||||||
// Endpoint is optional
|
// Endpoint is optional.
|
||||||
|
// If allowedIps is given it will be added to the existing ones.
|
||||||
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (w *WGIface) UpdatePeer(peerKey string, allowedIps []netip.Prefix, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
netIPNets := prefixesToIPNets(allowedIps)
|
log.Debugf("updating interface %s peer %s, endpoint %s, allowedIPs %v", w.tun.DeviceName(), peerKey, endpoint, allowedIps)
|
||||||
log.Debugf("updating interface %s peer %s, endpoint %s", w.tun.DeviceName(), peerKey, endpoint)
|
return w.configurer.UpdatePeer(peerKey, allowedIps, keepAlive, endpoint, preSharedKey)
|
||||||
return w.configurer.UpdatePeer(peerKey, netIPNets, keepAlive, endpoint, preSharedKey)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePeer removes a Wireguard Peer from the interface iface
|
// RemovePeer removes a Wireguard Peer from the interface iface
|
||||||
@@ -131,7 +132,7 @@ func (w *WGIface) RemovePeer(peerKey string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
// AddAllowedIP adds a prefix to the allowed IPs list of peer
|
||||||
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
func (w *WGIface) AddAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
@@ -140,7 +141,7 @@ func (w *WGIface) AddAllowedIP(peerKey string, allowedIP string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
// RemoveAllowedIP removes a prefix from the allowed IPs list of peer
|
||||||
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP string) error {
|
func (w *WGIface) RemoveAllowedIP(peerKey string, allowedIP netip.Prefix) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
@@ -185,7 +186,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.filter = filter
|
w.filter = filter
|
||||||
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
|
||||||
|
|
||||||
w.tun.FilteredDevice().SetFilter(filter)
|
w.tun.FilteredDevice().SetFilter(filter)
|
||||||
return nil
|
return nil
|
||||||
@@ -212,9 +212,21 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
|||||||
return w.tun.Device()
|
return w.tun.Device()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) LastActivities() map[string]time.Time {
|
||||||
|
w.mu.Lock()
|
||||||
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
return w.configurer.LastActivities()
|
||||||
|
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||||
|
return w.configurer.FullStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WGIface) waitUntilRemoved() error {
|
func (w *WGIface) waitUntilRemoved() error {
|
||||||
@@ -251,14 +263,3 @@ func (w *WGIface) GetNet() *netstack.Net {
|
|||||||
|
|
||||||
return w.tun.GetNet()
|
return w.tun.GetNet()
|
||||||
}
|
}
|
||||||
|
|
||||||
func prefixesToIPNets(prefixes []netip.Prefix) []net.IPNet {
|
|
||||||
ipNets := make([]net.IPNet, len(prefixes))
|
|
||||||
for i, prefix := range prefixes {
|
|
||||||
ipNets[i] = net.IPNet{
|
|
||||||
IP: net.IP(prefix.Addr().AsSlice()), // Convert netip.Addr to net.IP
|
|
||||||
Mask: net.CIDRMask(prefix.Bits(), prefix.Addr().BitLen()), // Create subnet mask
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return ipNets
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func NewWGIFace(opts WGIFaceOpts) (*WGIface, error) {
|
|||||||
|
|
||||||
wgIFace := &WGIface{
|
wgIFace := &WGIface{
|
||||||
userspaceBind: true,
|
userspaceBind: true,
|
||||||
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter),
|
tun: device.NewTunDevice(wgAddress, opts.WGPort, opts.WGPrivKey, opts.MTU, iceBind, opts.MobileArgs.TunAdapter, opts.DisableDNS),
|
||||||
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
wgProxyFactory: wgproxy.NewUSPFactory(iceBind),
|
||||||
}
|
}
|
||||||
return wgIFace, nil
|
return wgIFace, nil
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
@@ -49,32 +48,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// FilterInbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) FilterInbound(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
ret := m.ctrl.Call(m, "FilterInbound", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// FilterInbound indicates an expected call of FilterInbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// FilterOutbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
func (m *MockPacketFilter) FilterOutbound(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
ret := m.ctrl.Call(m, "FilterOutbound", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
|
||||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork indicates an expected call of SetNetwork.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -46,32 +46,32 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddUDPPacketHook", reflect.TypeOf((*MockPacketFilter)(nil).AddUDPPacketHook), arg0, arg1, arg2, arg3)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// FilterInbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) FilterInbound(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "FilterInbound", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// FilterInbound indicates an expected call of FilterInbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterInbound(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterInbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterInbound), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// FilterOutbound mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) FilterOutbound(arg0 []byte) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "FilterOutbound", arg0)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// FilterOutbound indicates an expected call of FilterOutbound.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) FilterOutbound(arg0 interface{}) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilterOutbound", reflect.TypeOf((*MockPacketFilter)(nil).FilterOutbound), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
// SetNetwork mocks base method.
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -15,8 +13,8 @@ import (
|
|||||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||||
|
|
||||||
type NetStackTun struct { //nolint:revive
|
type NetStackTun struct { //nolint:revive
|
||||||
address net.IP
|
address netip.Addr
|
||||||
dnsAddress net.IP
|
dnsAddress netip.Addr
|
||||||
mtu int
|
mtu int
|
||||||
listenAddress string
|
listenAddress string
|
||||||
|
|
||||||
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
|
|||||||
tundev tun.Device
|
tundev tun.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||||
return &NetStackTun{
|
return &NetStackTun{
|
||||||
address: address,
|
address: address,
|
||||||
dnsAddress: dnsAddress,
|
dnsAddress: dnsAddress,
|
||||||
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||||
addr, ok := netip.AddrFromSlice(t.address)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||||
[]netip.Addr{addr.Unmap()},
|
[]netip.Addr{t.address},
|
||||||
[]netip.Addr{dnsAddr.Unmap()},
|
[]netip.Addr{t.dnsAddress},
|
||||||
t.mtu)
|
t.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -2,28 +2,27 @@ package wgaddr
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type Address struct {
|
type Address struct {
|
||||||
IP net.IP
|
IP netip.Addr
|
||||||
Network *net.IPNet
|
Network netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
func ParseWGAddress(address string) (Address, error) {
|
func ParseWGAddress(address string) (Address, error) {
|
||||||
ip, network, err := net.ParseCIDR(address)
|
prefix, err := netip.ParsePrefix(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Address{}, err
|
return Address{}, err
|
||||||
}
|
}
|
||||||
return Address{
|
return Address{
|
||||||
IP: ip,
|
IP: prefix.Addr().Unmap(),
|
||||||
Network: network,
|
Network: prefix.Masked(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (addr Address) String() string {
|
func (addr Address) String() string {
|
||||||
maskSize, _ := addr.Network.Mask.Size()
|
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,8 @@
|
|||||||
|
|
||||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||||
|
|
||||||
|
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
|
||||||
|
|
||||||
Unicode True
|
Unicode True
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
@@ -49,6 +51,10 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
|
!include "MUI2.nsh"
|
||||||
|
!include LogicLib.nsh
|
||||||
|
!include "nsDialogs.nsh"
|
||||||
|
|
||||||
!define MUI_ICON "${ICON}"
|
!define MUI_ICON "${ICON}"
|
||||||
!define MUI_UNICON "${ICON}"
|
!define MUI_UNICON "${ICON}"
|
||||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||||
@@ -58,9 +64,6 @@ ShowInstDetails Show
|
|||||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
!include "MUI2.nsh"
|
|
||||||
!include LogicLib.nsh
|
|
||||||
|
|
||||||
!define MUI_ABORTWARNING
|
!define MUI_ABORTWARNING
|
||||||
!define MUI_UNABORTWARNING
|
!define MUI_UNABORTWARNING
|
||||||
|
|
||||||
@@ -70,13 +73,16 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_PAGE_DIRECTORY
|
!insertmacro MUI_PAGE_DIRECTORY
|
||||||
|
|
||||||
; Custom page for autostart checkbox
|
|
||||||
Page custom AutostartPage AutostartPageLeave
|
Page custom AutostartPage AutostartPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_INSTFILES
|
!insertmacro MUI_PAGE_INSTFILES
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_FINISH
|
!insertmacro MUI_PAGE_FINISH
|
||||||
|
|
||||||
|
!insertmacro MUI_UNPAGE_WELCOME
|
||||||
|
|
||||||
|
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_CONFIRM
|
!insertmacro MUI_UNPAGE_CONFIRM
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_INSTFILES
|
!insertmacro MUI_UNPAGE_INSTFILES
|
||||||
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
|
|||||||
Var AutostartCheckbox
|
Var AutostartCheckbox
|
||||||
Var AutostartEnabled
|
Var AutostartEnabled
|
||||||
|
|
||||||
|
; Variables for uninstall data deletion option
|
||||||
|
Var DeleteDataCheckbox
|
||||||
|
Var DeleteDataEnabled
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
; Function to create the autostart options page
|
; Function to create the autostart options page
|
||||||
@@ -104,8 +114,8 @@ Function AutostartPage
|
|||||||
|
|
||||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||||
Pop $AutostartCheckbox
|
Pop $AutostartCheckbox
|
||||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
${NSD_Check} $AutostartCheckbox
|
||||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
StrCpy $AutostartEnabled "1"
|
||||||
|
|
||||||
nsDialogs::Show
|
nsDialogs::Show
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
@@ -115,6 +125,30 @@ Function AutostartPageLeave
|
|||||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to create the uninstall data deletion page
|
||||||
|
Function un.DeleteDataPage
|
||||||
|
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
|
||||||
|
|
||||||
|
nsDialogs::Create 1018
|
||||||
|
Pop $0
|
||||||
|
|
||||||
|
${If} $0 == error
|
||||||
|
Abort
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
|
||||||
|
Pop $DeleteDataCheckbox
|
||||||
|
${NSD_Uncheck} $DeleteDataCheckbox
|
||||||
|
StrCpy $DeleteDataEnabled "0"
|
||||||
|
|
||||||
|
nsDialogs::Show
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to handle leaving the data deletion page
|
||||||
|
Function un.DeleteDataPageLeave
|
||||||
|
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
Function GetAppFromCommand
|
Function GetAppFromCommand
|
||||||
Exch $1
|
Exch $1
|
||||||
Push $2
|
Push $2
|
||||||
@@ -176,10 +210,10 @@ ${EndIf}
|
|||||||
FunctionEnd
|
FunctionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
Section -MainProgram
|
Section -MainProgram
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
# SetOverwrite ifnewer
|
# SetOverwrite ifnewer
|
||||||
SetOutPath "$INSTDIR"
|
SetOutPath "$INSTDIR"
|
||||||
File /r "..\\dist\\netbird_windows_amd64\\"
|
File /r "..\\dist\\netbird_windows_amd64\\"
|
||||||
SectionEnd
|
SectionEnd
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
@@ -225,31 +259,58 @@ SectionEnd
|
|||||||
Section Uninstall
|
Section Uninstall
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
|
|
||||||
|
DetailPrint "Stopping Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||||
|
DetailPrint "Uninstalling Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||||
|
|
||||||
# kill ui client
|
DetailPrint "Terminating Netbird UI process..."
|
||||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||||
|
|
||||||
; Remove autostart registry entry
|
; Remove autostart registry entry
|
||||||
|
DetailPrint "Removing autostart registry entry if exists..."
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
|
; Handle data deletion based on checkbox
|
||||||
|
DetailPrint "Checking if user requested data deletion..."
|
||||||
|
${If} $DeleteDataEnabled == "1"
|
||||||
|
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
|
||||||
|
ClearErrors
|
||||||
|
RMDir /r "${NETBIRD_DATA_DIR}"
|
||||||
|
IfErrors 0 +2 ; If no errors, jump over the message
|
||||||
|
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
|
||||||
|
DetailPrint "Netbird data directory removal complete."
|
||||||
|
${Else}
|
||||||
|
DetailPrint "User did not opt to delete Netbird data."
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
# wait the service uninstall take unblock the executable
|
# wait the service uninstall take unblock the executable
|
||||||
|
DetailPrint "Waiting for service handle to be released..."
|
||||||
Sleep 3000
|
Sleep 3000
|
||||||
|
|
||||||
|
DetailPrint "Deleting application files..."
|
||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
Delete "$INSTDIR\wintun.dll"
|
Delete "$INSTDIR\wintun.dll"
|
||||||
Delete "$INSTDIR\opengl32.dll"
|
Delete "$INSTDIR\opengl32.dll"
|
||||||
|
DetailPrint "Removing application directory..."
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Removing shortcuts..."
|
||||||
SetShellVarContext all
|
SetShellVarContext all
|
||||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||||
|
|
||||||
|
DetailPrint "Removing registry keys..."
|
||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
|
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||||
|
|
||||||
|
DetailPrint "Removing application directory from PATH..."
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::DeleteValue "path" "$INSTDIR"
|
EnVar::DeleteValue "path" "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Uninstallation finished."
|
||||||
SectionEnd
|
SectionEnd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
|
|||||||
|
|
||||||
func GenerateRouteRuleKey(
|
func GenerateRouteRuleKey(
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination manager.Network,
|
||||||
proto manager.Protocol,
|
proto manager.Protocol,
|
||||||
sPort *manager.Port,
|
sPort *manager.Port,
|
||||||
dPort *manager.Port,
|
dPort *manager.Port,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
|||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type protoMatch struct {
|
type protoMatch struct {
|
||||||
@@ -53,10 +54,15 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
|||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||||
//
|
//
|
||||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
||||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
|
||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
|
|
||||||
|
if d.firewall == nil {
|
||||||
|
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
total := 0
|
total := 0
|
||||||
@@ -68,21 +74,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
time.Since(start), total)
|
time.Since(start), total)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if d.firewall == nil {
|
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d.applyPeerACLs(networkMap)
|
d.applyPeerACLs(networkMap)
|
||||||
|
|
||||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
|
||||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
|
||||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
|
||||||
log.Errorf("failed to set legacy management flag: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
|
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,16 +170,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
d.peerRulesPairs = newRulePairs
|
d.peerRulesPairs = newRulePairs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
id, err := d.applyRouteACL(rule)
|
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||||
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||||
} else {
|
} else {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||||
}
|
}
|
||||||
@@ -208,7 +202,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
||||||
if len(rule.SourceRanges) == 0 {
|
if len(rule.SourceRanges) == 0 {
|
||||||
return "", ErrSourceRangesEmpty
|
return "", ErrSourceRangesEmpty
|
||||||
}
|
}
|
||||||
@@ -222,15 +216,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
sources = append(sources, source)
|
sources = append(sources, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
var destination netip.Prefix
|
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||||
if rule.IsDynamic {
|
if err != nil {
|
||||||
destination = getDefault(sources[0])
|
return "", fmt.Errorf("determine destination: %w", err)
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
destination, err = netip.ParsePrefix(rule.Destination)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("parse destination: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||||
@@ -296,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
if d.firewall.IsStateful() {
|
||||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
// return traffic for outbound connections if firewall is stateless
|
||||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
@@ -408,11 +398,15 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
//
|
//
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
hasPortRestrictions := r.Action == mgmProto.RuleAction_DROP ||
|
||||||
if drop {
|
r.Port != "" || !portInfoEmpty(r.PortInfo)
|
||||||
|
|
||||||
|
if hasPortRestrictions {
|
||||||
|
// Don't squash rules with port restrictions
|
||||||
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = &protoMatch{
|
protocols[r.Protocol] = &protoMatch{
|
||||||
ips: map[string]int{},
|
ips: map[string]int{},
|
||||||
@@ -580,6 +574,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
|
||||||
|
var destination firewall.Network
|
||||||
|
|
||||||
|
if rule.IsDynamic {
|
||||||
|
if dynamicResolver {
|
||||||
|
if len(rule.Domains) > 0 {
|
||||||
|
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
|
||||||
|
} else {
|
||||||
|
// isDynamic is set but no domains = outdated management server
|
||||||
|
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
|
||||||
|
destination.Prefix = getDefault(sources[0])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// client resolves DNS, we (router) don't know the destination
|
||||||
|
destination.Prefix = getDefault(sources[0])
|
||||||
|
}
|
||||||
|
return destination, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, err := netip.ParsePrefix(rule.Destination)
|
||||||
|
if err != nil {
|
||||||
|
return destination, fmt.Errorf("parse destination: %w", err)
|
||||||
|
}
|
||||||
|
destination.Prefix = prefix
|
||||||
|
return destination, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getDefault(prefix netip.Prefix) netip.Prefix {
|
func getDefault(prefix netip.Prefix) netip.Prefix {
|
||||||
if prefix.Addr().Is6() {
|
if prefix.Addr().Is6() {
|
||||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Errorf("create firewall: %v", err)
|
defer func() {
|
||||||
return
|
err = fw.Close(nil)
|
||||||
}
|
require.NoError(t, err)
|
||||||
defer func(fw manager.Manager) {
|
}()
|
||||||
_ = fw.Close(nil)
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
t.Run("apply firewall rules", func(t *testing.T) {
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if fw.IsStateful() {
|
||||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||||
return
|
} else {
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -92,14 +89,15 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
expectedRules := 2
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if fw.IsStateful() {
|
||||||
t.Errorf("firewall rules not applied")
|
expectedRules = 1 // only the inbound rule
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if previousCount != 1 {
|
|
||||||
t.Errorf("old rule was not removed")
|
expectedPreviousCount := 0
|
||||||
|
if !fw.IsStateful() {
|
||||||
|
expectedPreviousCount = 1
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedPreviousCount, previousCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handle default rules", func(t *testing.T) {
|
t.Run("handle default rules", func(t *testing.T) {
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
networkMap.FirewallRulesIsEmpty = true
|
||||||
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
|
acl.ApplyFiltering(networkMap, false)
|
||||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
if len(acl.peerRulesPairs) != 1 {
|
|
||||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
expectedRules := 1
|
||||||
return
|
if fw.IsStateful() {
|
||||||
|
expectedRules = 1 // only inbound allow-all rule
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultManagerStateless(t *testing.T) {
|
||||||
|
// stateless currently only in userspace, so we have to disable kernel
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
Port: "53",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = fw.Close(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// In stateless mode, we should have both inbound and outbound rules
|
||||||
|
assert.False(t, fw.IsStateful())
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
|||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
if len(rules) != 2 {
|
assert.Equal(t, 2, len(rules))
|
||||||
t.Errorf("rules should contain 2, got: %v", rules)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r := rules[0]
|
r := rules[0]
|
||||||
switch {
|
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||||
case r.PeerIP != "0.0.0.0":
|
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||||
return
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||||
case r.Direction != mgmProto.RuleDirection_IN:
|
|
||||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r = rules[1]
|
r = rules[1]
|
||||||
switch {
|
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||||
case r.PeerIP != "0.0.0.0":
|
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||||
return
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||||
case r.Direction != mgmProto.RuleDirection_OUT:
|
|
||||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||||
@@ -291,8 +326,435 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultManagerSquashRulesWithPortRestrictions(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
rules []*mgmProto.FirewallRule
|
||||||
|
expectedCount int
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "should not squash rules with port ranges",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with port ranges should not be squashed even if they cover all peers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with specific ports",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with specific ports should not be squashed even if they cover all peers",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with legacy port field",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with legacy port field should not be squashed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should not squash rules with DROP action",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_DROP,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "Rules with DROP action should not be squashed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should squash rules without port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 1,
|
||||||
|
description: "Rules without port restrictions should be squashed into a single 0.0.0.0 rule",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed rules should not squash protocol with port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
PortInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 4,
|
||||||
|
description: "TCP should not be squashed because one rule has port restrictions",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "should squash UDP but not TCP when TCP has port restrictions",
|
||||||
|
rules: []*mgmProto.FirewallRule{
|
||||||
|
// TCP rules with port restrictions - should NOT be squashed
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "443",
|
||||||
|
},
|
||||||
|
// UDP rules without port restrictions - SHOULD be squashed
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.3",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.4",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedCount: 5, // 4 TCP rules + 1 squashed UDP rule (0.0.0.0)
|
||||||
|
description: "UDP should be squashed to 0.0.0.0 rule, but TCP should remain as individual rules due to port restrictions",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||||
|
{AllowedIps: []string{"10.93.0.1"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.2"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.3"}},
|
||||||
|
{AllowedIps: []string{"10.93.0.4"}},
|
||||||
|
},
|
||||||
|
FirewallRules: tt.rules,
|
||||||
|
}
|
||||||
|
|
||||||
|
manager := &DefaultManager{}
|
||||||
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
|
|
||||||
|
assert.Equal(t, tt.expectedCount, len(rules), tt.description)
|
||||||
|
|
||||||
|
// For squashed rules, verify we get the expected 0.0.0.0 rule
|
||||||
|
if tt.expectedCount == 1 {
|
||||||
|
assert.Equal(t, "0.0.0.0", rules[0].PeerIP)
|
||||||
|
assert.Equal(t, mgmProto.RuleDirection_IN, rules[0].Direction)
|
||||||
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, rules[0].Action)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestPortInfoEmpty(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
portInfo *mgmProto.PortInfo
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "nil PortInfo should be empty",
|
||||||
|
portInfo: nil,
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero port should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with valid port should not be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Port{
|
||||||
|
Port: 80,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with nil range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: nil,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero start range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 0,
|
||||||
|
End: 100,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with zero end range should be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 80,
|
||||||
|
End: 0,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "PortInfo with valid range should not be empty",
|
||||||
|
portInfo: &mgmProto.PortInfo{
|
||||||
|
PortSelection: &mgmProto.PortInfo_Range_{
|
||||||
|
Range: &mgmProto.PortInfo_Range{
|
||||||
|
Start: 8080,
|
||||||
|
End: 8090,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
result := portInfoEmpty(tt.portInfo)
|
||||||
|
assert.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,33 +798,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Errorf("create firewall: %v", err)
|
defer func() {
|
||||||
return
|
err = fw.Close(nil)
|
||||||
}
|
require.NoError(t, err)
|
||||||
defer func(fw manager.Manager) {
|
}()
|
||||||
_ = fw.Close(nil)
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 3 {
|
expectedRules := 3
|
||||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
if fw.IsStateful() {
|
||||||
return
|
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
|
||||||
if runtime.GOOS == "freebsd" {
|
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -101,7 +101,12 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
}
|
}
|
||||||
if !p.providerConfig.DisablePromptLogin {
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
|
}
|
||||||
|
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|||||||
@@ -7,15 +7,36 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client/common"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPromptLogin(t *testing.T) {
|
func TestPromptLogin(t *testing.T) {
|
||||||
|
const (
|
||||||
|
promptLogin = "prompt=login"
|
||||||
|
maxAge0 = "max_age=0"
|
||||||
|
)
|
||||||
|
|
||||||
tt := []struct {
|
tt := []struct {
|
||||||
name string
|
name string
|
||||||
prompt bool
|
loginFlag mgm.LoginFlag
|
||||||
|
disablePromptLogin bool
|
||||||
|
expect string
|
||||||
}{
|
}{
|
||||||
{"PromptLogin", true},
|
{
|
||||||
{"NoPromptLogin", false},
|
name: "Prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
expect: promptLogin,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max age 0 login",
|
||||||
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
|
expect: maxAge0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disable prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
disablePromptLogin: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tt {
|
for _, tc := range tt {
|
||||||
@@ -28,7 +49,7 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
UseIDToken: true,
|
UseIDToken: true,
|
||||||
DisablePromptLogin: !tc.prompt,
|
LoginFlag: tc.loginFlag,
|
||||||
}
|
}
|
||||||
pkce, err := NewPKCEAuthorizationFlow(config)
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -38,11 +59,12 @@ func TestPromptLogin(t *testing.T) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to request auth info: %v", err)
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
}
|
}
|
||||||
pattern := "prompt=login"
|
|
||||||
if tc.prompt {
|
if !tc.disablePromptLogin {
|
||||||
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||||
} else {
|
} else {
|
||||||
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||||
|
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -68,12 +68,14 @@ type ConfigInput struct {
|
|||||||
DisableServerRoutes *bool
|
DisableServerRoutes *bool
|
||||||
DisableDNS *bool
|
DisableDNS *bool
|
||||||
DisableFirewall *bool
|
DisableFirewall *bool
|
||||||
|
BlockLANAccess *bool
|
||||||
BlockLANAccess *bool
|
BlockInbound *bool
|
||||||
|
|
||||||
DisableNotifications *bool
|
DisableNotifications *bool
|
||||||
|
|
||||||
DNSLabels domain.List
|
DNSLabels domain.List
|
||||||
|
|
||||||
|
LazyConnectionEnabled *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@@ -96,8 +98,8 @@ type Config struct {
|
|||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
DisableDNS bool
|
DisableDNS bool
|
||||||
DisableFirewall bool
|
DisableFirewall bool
|
||||||
|
BlockLANAccess bool
|
||||||
BlockLANAccess bool
|
BlockInbound bool
|
||||||
|
|
||||||
DisableNotifications *bool
|
DisableNotifications *bool
|
||||||
|
|
||||||
@@ -138,6 +140,8 @@ type Config struct {
|
|||||||
ClientCertKeyPath string
|
ClientCertKeyPath string
|
||||||
|
|
||||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||||
|
|
||||||
|
LazyConnectionEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
@@ -219,6 +223,8 @@ func createNewConfig(input ConfigInput) (*Config, error) {
|
|||||||
config := &Config{
|
config := &Config{
|
||||||
// defaults to false only for new (post 0.26) configurations
|
// defaults to false only for new (post 0.26) configurations
|
||||||
ServerSSHAllowed: util.False(),
|
ServerSSHAllowed: util.False(),
|
||||||
|
// default to disabling server routes on Android for security
|
||||||
|
DisableServerRoutes: runtime.GOOS == "android",
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, err := config.apply(input); err != nil {
|
if _, err := config.apply(input); err != nil {
|
||||||
@@ -313,10 +319,6 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
*input.WireguardPort, config.WgPort)
|
*input.WireguardPort, config.WgPort)
|
||||||
config.WgPort = *input.WireguardPort
|
config.WgPort = *input.WireguardPort
|
||||||
updated = true
|
updated = true
|
||||||
} else if config.WgPort == 0 {
|
|
||||||
config.WgPort = iface.DefaultWgPort
|
|
||||||
log.Infof("using default Wireguard port %d", config.WgPort)
|
|
||||||
updated = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
if input.InterfaceName != nil && *input.InterfaceName != config.WgIface {
|
||||||
@@ -412,9 +414,15 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
config.ServerSSHAllowed = input.ServerSSHAllowed
|
config.ServerSSHAllowed = input.ServerSSHAllowed
|
||||||
updated = true
|
updated = true
|
||||||
} else if config.ServerSSHAllowed == nil {
|
} else if config.ServerSSHAllowed == nil {
|
||||||
// enables SSH for configs from old versions to preserve backwards compatibility
|
if runtime.GOOS == "android" {
|
||||||
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
// default to disabled SSH on Android for security
|
||||||
config.ServerSSHAllowed = util.True()
|
log.Infof("setting SSH server to false by default on Android")
|
||||||
|
config.ServerSSHAllowed = util.False()
|
||||||
|
} else {
|
||||||
|
// enables SSH for configs from old versions to preserve backwards compatibility
|
||||||
|
log.Infof("falling back to enabled SSH server for pre-existing configuration")
|
||||||
|
config.ServerSSHAllowed = util.True()
|
||||||
|
}
|
||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,6 +487,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
|
||||||
|
if *input.BlockInbound {
|
||||||
|
log.Infof("blocking inbound connections")
|
||||||
|
} else {
|
||||||
|
log.Infof("allowing inbound connections")
|
||||||
|
}
|
||||||
|
config.BlockInbound = *input.BlockInbound
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||||
if *input.DisableNotifications {
|
if *input.DisableNotifications {
|
||||||
log.Infof("disabling notifications")
|
log.Infof("disabling notifications")
|
||||||
@@ -524,6 +542,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||||
|
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||||
|
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
326
client/internal/conn_mgr.go
Normal file
326
client/internal/conn_mgr.go
Normal file
@@ -0,0 +1,326 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
"github.com/netbirdio/netbird/route"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||||
|
//
|
||||||
|
// The connection manager is responsible for:
|
||||||
|
// - Managing lazy connections via the lazyConnManager
|
||||||
|
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||||
|
// - Handling connection establishment based on peer signaling
|
||||||
|
//
|
||||||
|
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||||
|
type ConnMgr struct {
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
iface lazyconn.WGIface
|
||||||
|
enabledLocally bool
|
||||||
|
rosenpassEnabled bool
|
||||||
|
|
||||||
|
lazyConnMgr *manager.Manager
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
lazyCtx context.Context
|
||||||
|
lazyCtxCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface) *ConnMgr {
|
||||||
|
e := &ConnMgr{
|
||||||
|
peerStore: peerStore,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
iface: iface,
|
||||||
|
rosenpassEnabled: engineConfig.RosenpassEnabled,
|
||||||
|
}
|
||||||
|
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||||
|
e.enabledLocally = true
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||||
|
func (e *ConnMgr) Start(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
log.Errorf("lazy connection manager is already started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.enabledLocally {
|
||||||
|
log.Infof("lazy connection manager is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.rosenpassEnabled {
|
||||||
|
log.Warnf("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
|
||||||
|
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||||
|
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||||
|
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||||
|
// do not disable lazy connection manager if it was enabled by env var
|
||||||
|
if e.enabledLocally {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if enabled {
|
||||||
|
// if the lazy connection manager is already started, do not start it again
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if e.rosenpassEnabled {
|
||||||
|
log.Infof("rosenpass connection manager is enabled, lazy connection manager will not be started")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Warnf("lazy connection manager is enabled by management feature flag")
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
return e.addPeersToLazyConnManager()
|
||||||
|
} else {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||||
|
e.closeManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateRouteHAMap updates the route HA mappings in the lazy connection manager
|
||||||
|
func (e *ConnMgr) UpdateRouteHAMap(haMap route.HAMap) {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
log.Debugf("lazy connection manager is not started, skipping UpdateRouteHAMap")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyConnMgr.UpdateRouteHAMap(haMap)
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
|
||||||
|
func (e *ConnMgr) SetExcludeList(ctx context.Context, peerIDs map[string]bool) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
|
||||||
|
|
||||||
|
for peerID := range peerIDs {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
added := e.lazyConnMgr.ExcludePeer(excludedPeers)
|
||||||
|
for _, peerID := range added {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
|
||||||
|
if err := peerConn.Open(ctx); err != nil {
|
||||||
|
peerConn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
|
||||||
|
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !lazyconn.IsSupported(conn.AgentVersionString()) {
|
||||||
|
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerKey,
|
||||||
|
AllowedIPs: conn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: conn.ConnID(),
|
||||||
|
Log: conn.Log,
|
||||||
|
}
|
||||||
|
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||||
|
if err != nil {
|
||||||
|
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if excluded {
|
||||||
|
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Log.Infof("peer added to lazy conn manager")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||||
|
conn, ok := e.peerStore.Remove(peerKey)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close(false)
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyConnMgr.RemovePeer(peerKey)
|
||||||
|
conn.Log.Infof("removed peer from lazy conn manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) ActivatePeer(ctx context.Context, conn *peer.Conn) {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if found := e.lazyConnMgr.ActivatePeer(conn.GetKey()); found {
|
||||||
|
conn.Log.Infof("activated peer from inactive state")
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeactivatePeer deactivates a peer connection in the lazy connection manager.
|
||||||
|
// If locally the lazy connection is disabled, we force the peer connection open.
|
||||||
|
func (e *ConnMgr) DeactivatePeer(conn *peer.Conn) {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Log.Infof("closing peer connection: remote peer initiated inactive, idle lazy state and sent GOAWAY")
|
||||||
|
e.lazyConnMgr.DeactivatePeer(conn.ConnID())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) Close() {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyCtxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) initLazyManager(engineCtx context.Context) {
|
||||||
|
cfg := manager.Config{
|
||||||
|
InactivityThreshold: inactivityThresholdEnv(),
|
||||||
|
}
|
||||||
|
e.lazyConnMgr = manager.NewManager(cfg, engineCtx, e.peerStore, e.iface)
|
||||||
|
|
||||||
|
e.lazyCtx, e.lazyCtxCancel = context.WithCancel(engineCtx)
|
||||||
|
|
||||||
|
e.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer e.wg.Done()
|
||||||
|
e.lazyConnMgr.Start(e.lazyCtx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) addPeersToLazyConnManager() error {
|
||||||
|
peers := e.peerStore.PeersPubKey()
|
||||||
|
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
|
||||||
|
for _, peerID := range peers {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.lazyConnMgr.AddActivePeers(lazyPeerCfgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyCtxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
|
||||||
|
for _, peerID := range e.peerStore.PeersPubKey() {
|
||||||
|
e.peerStore.PeerConnOpen(ctx, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||||
|
return e.lazyConnMgr != nil && e.lazyCtxCancel != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func inactivityThresholdEnv() *time.Duration {
|
||||||
|
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||||
|
if envValue == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedMinutes, err := strconv.Atoi(envValue)
|
||||||
|
if err != nil || parsedMinutes <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d := time.Duration(parsedMinutes) * time.Minute
|
||||||
|
return &d
|
||||||
|
}
|
||||||
@@ -17,7 +17,6 @@ import (
|
|||||||
"google.golang.org/grpc/codes"
|
"google.golang.org/grpc/codes"
|
||||||
gstatus "google.golang.org/grpc/status"
|
gstatus "google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
@@ -349,6 +348,25 @@ func (c *ConnectClient) Engine() *Engine {
|
|||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLatestNetworkMap returns the latest network map from the engine.
|
||||||
|
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||||
|
engine := c.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, errors.New("engine is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
networkMap, err := engine.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get latest network map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if networkMap == nil {
|
||||||
|
return nil, errors.New("network map is not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return networkMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current client status
|
// Status returns the current client status
|
||||||
func (c *ConnectClient) Status() StatusType {
|
func (c *ConnectClient) Status() StatusType {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -417,11 +435,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
DisableServerRoutes: config.DisableServerRoutes,
|
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||||
DisableDNS: config.DisableDNS,
|
DisableDNS: config.DisableDNS,
|
||||||
DisableFirewall: config.DisableFirewall,
|
DisableFirewall: config.DisableFirewall,
|
||||||
|
BlockLANAccess: config.BlockLANAccess,
|
||||||
|
BlockInbound: config.BlockInbound,
|
||||||
|
|
||||||
BlockLANAccess: config.BlockLANAccess,
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
@@ -462,7 +482,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
return signalClient, nil
|
return signalClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
@@ -479,6 +499,9 @@ func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte,
|
|||||||
config.DisableServerRoutes,
|
config.DisableServerRoutes,
|
||||||
config.DisableDNS,
|
config.DisableDNS,
|
||||||
config.DisableFirewall,
|
config.DisableFirewall,
|
||||||
|
config.BlockLANAccess,
|
||||||
|
config.BlockInbound,
|
||||||
|
config.LazyConnectionEnabled,
|
||||||
)
|
)
|
||||||
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
loginResp, err := client.Login(*serverPublicKey, sysInfo, pubSSHKey, config.DNSLabels)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -502,17 +525,13 @@ func statusRecorderToSignalConnStateNotifier(statusRecorder *peer.Status) signal
|
|||||||
|
|
||||||
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
// freePort attempts to determine if the provided port is available, if not it will ask the system for a free port.
|
||||||
func freePort(initPort int) (int, error) {
|
func freePort(initPort int) (int, error) {
|
||||||
addr := net.UDPAddr{}
|
addr := net.UDPAddr{Port: initPort}
|
||||||
if initPort == 0 {
|
|
||||||
initPort = iface.DefaultWgPort
|
|
||||||
}
|
|
||||||
|
|
||||||
addr.Port = initPort
|
|
||||||
|
|
||||||
conn, err := net.ListenUDP("udp", &addr)
|
conn, err := net.ListenUDP("udp", &addr)
|
||||||
if err == nil {
|
if err == nil {
|
||||||
|
returnPort := conn.LocalAddr().(*net.UDPAddr).Port
|
||||||
closeConnWithLog(conn)
|
closeConnWithLog(conn)
|
||||||
return initPort, nil
|
return returnPort, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// if the port is already in use, ask the system for a free port
|
// if the port is already in use, ask the system for a free port
|
||||||
|
|||||||
@@ -13,10 +13,10 @@ func Test_freePort(t *testing.T) {
|
|||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "not provided, fallback to default",
|
name: "when port is 0 use random port",
|
||||||
port: 0,
|
port: 0,
|
||||||
want: 51820,
|
want: 0,
|
||||||
shouldMatch: true,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "provided and available",
|
name: "provided and available",
|
||||||
@@ -31,7 +31,7 @@ func Test_freePort(t *testing.T) {
|
|||||||
shouldMatch: false,
|
shouldMatch: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 51830})
|
c1, err := net.ListenUDP("udp", &net.UDPAddr{Port: 0})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("freePort error = %v", err)
|
t.Errorf("freePort error = %v", err)
|
||||||
}
|
}
|
||||||
@@ -39,6 +39,14 @@ func Test_freePort(t *testing.T) {
|
|||||||
_ = c1.Close()
|
_ = c1.Close()
|
||||||
}(c1)
|
}(c1)
|
||||||
|
|
||||||
|
if tests[1].port == c1.LocalAddr().(*net.UDPAddr).Port {
|
||||||
|
tests[1].port++
|
||||||
|
tests[1].want++
|
||||||
|
}
|
||||||
|
|
||||||
|
tests[2].port = c1.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
tests[2].want = c1.LocalAddr().(*net.UDPAddr).Port
|
||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
|
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
|||||||
1112
client/internal/debug/debug.go
Normal file
1112
client/internal/debug/debug.go
Normal file
File diff suppressed because it is too large
Load Diff
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user