mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-22 18:26:41 +00:00
Compare commits
158 Commits
v0.39.2
...
add-ns-pun
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
ff5eddf70b | ||
|
|
0f050e5fe1 | ||
|
|
0f7c7f1da2 | ||
|
|
b56f61bf1b | ||
|
|
64f111923e | ||
|
|
122a89c02b | ||
|
|
c6cceba381 | ||
|
|
6c0cdb6ed1 | ||
|
|
84354951d3 | ||
|
|
55957a1960 | ||
|
|
df82a45d99 | ||
|
|
9424b88db2 | ||
|
|
609654eee7 | ||
|
|
b604c66140 | ||
|
|
ea4d13e96d | ||
|
|
87148c503f | ||
|
|
0cd36baf67 | ||
|
|
06980e7fa0 | ||
|
|
1ce4ee0cef | ||
|
|
f367925496 | ||
|
|
616b19c064 | ||
|
|
af27aaf9af | ||
|
|
35287f8241 | ||
|
|
07b220d91b | ||
|
|
41cd4952f1 | ||
|
|
f16f0c7831 | ||
|
|
aa07b3b87b | ||
|
|
2bef214cc0 | ||
|
|
cfb2d82352 | ||
|
|
684501fd35 | ||
|
|
0492c1724a | ||
|
|
6f436e57b5 | ||
|
|
a0d28f9851 | ||
|
|
cdd27a9fe5 | ||
|
|
5523040acd | ||
|
|
670446d42e | ||
|
|
273160c682 | ||
|
|
5bed6777d5 | ||
|
|
a0482ebc7b | ||
|
|
1d6c360aec | ||
|
|
f04e7c3f06 | ||
|
|
2a89d6e47a | ||
|
|
3d89cd43c2 | ||
|
|
0eeda712d0 | ||
|
|
3e3268db5f | ||
|
|
31f0879e71 | ||
|
|
f25b5bb987 | ||
|
|
24f932b2ce | ||
|
|
c03435061c | ||
|
|
8e948739f1 | ||
|
|
9b53cad752 | ||
|
|
802a18167c | ||
|
|
e9108ffe6c | ||
|
|
e806d9de38 | ||
|
|
daa8380df9 | ||
|
|
4785f23fc4 | ||
|
|
1d4cfb83e7 | ||
|
|
207fa059d2 | ||
|
|
cbcdad7814 | ||
|
|
701c13807a | ||
|
|
99f8dc7748 | ||
|
|
f1de8e6eb0 | ||
|
|
b2a10780af | ||
|
|
43ae79d848 | ||
|
|
e520b64c6d | ||
|
|
92c91bbdd8 | ||
|
|
adf494e1ac | ||
|
|
2158461121 | ||
|
|
0cd4b601c3 | ||
|
|
ee1cec47b3 | ||
|
|
efb0edfc4c | ||
|
|
20f59ddecb | ||
|
|
2f34e984b0 | ||
|
|
d5b52e86b6 | ||
|
|
cad2fe1f39 | ||
|
|
fcd2c15a37 | ||
|
|
ebda0fc538 | ||
|
|
ac135ab11d | ||
|
|
25faf9283d | ||
|
|
59faaa99f6 | ||
|
|
9762b39f29 | ||
|
|
ffdd115ded | ||
|
|
055df9854c | ||
|
|
12f883badf | ||
|
|
2abb92b0d4 | ||
|
|
01c3719c5d | ||
|
|
7b64953eed | ||
|
|
9bc7d788f0 | ||
|
|
b5419ef11a | ||
|
|
d5081cef90 | ||
|
|
488e619ec7 | ||
|
|
d2b42c8f68 | ||
|
|
2f44fe2e23 | ||
|
|
d8dc107bee | ||
|
|
3fa915e271 | ||
|
|
47c3afe561 | ||
|
|
84bfecdd37 | ||
|
|
3cf87b6846 | ||
|
|
4fe4c2054d | ||
|
|
38ada44a0e | ||
|
|
dbf81a145e | ||
|
|
39483f8ca8 | ||
|
|
c0eaea938e | ||
|
|
ef8b8a2891 | ||
|
|
2817f62c13 | ||
|
|
4a9049566a | ||
|
|
85f92f8321 | ||
|
|
714beb6e3b | ||
|
|
400b9fca32 | ||
|
|
4013298e22 | ||
|
|
312bfd9bd7 | ||
|
|
8db05838ca | ||
|
|
c69df13515 | ||
|
|
986eb8c1e0 | ||
|
|
197761ba4d | ||
|
|
f74ea64c7b | ||
|
|
3b7b9d25bc | ||
|
|
1a6d6b3109 | ||
|
|
f686615876 | ||
|
|
a4311f574d | ||
|
|
0bb8eae903 | ||
|
|
e0b33d325d | ||
|
|
c38e07d89a | ||
|
|
a37368fff4 | ||
|
|
0c93bd3d06 | ||
|
|
a675531b5c | ||
|
|
7cb366bc7d | ||
|
|
a354004564 | ||
|
|
75bdd47dfb | ||
|
|
b165f63327 | ||
|
|
51bb52cdf5 | ||
|
|
4134b857b4 | ||
|
|
7839d2c169 | ||
|
|
b9f82e2f8a | ||
|
|
fd2a21c65d | ||
|
|
82d982b0ab | ||
|
|
9e24fe7701 | ||
|
|
e470701b80 | ||
|
|
e3ce026355 | ||
|
|
5ea2806663 | ||
|
|
d6b0673580 | ||
|
|
14913cfa7a | ||
|
|
03f600b576 | ||
|
|
192c97aa63 | ||
|
|
4db78db49a | ||
|
|
87e600a4f3 | ||
|
|
6162aeb82d | ||
|
|
1ba1e092ce | ||
|
|
86dbb4ee4f | ||
|
|
4af177215f | ||
|
|
df9c1b9883 | ||
|
|
5752bb78f2 | ||
|
|
fbd783ad58 | ||
|
|
80702b9323 | ||
|
|
09243a0fe0 | ||
|
|
3658215747 | ||
|
|
48ffec95dd | ||
|
|
cbec7bda80 |
27
.git-branches.toml
Normal file
27
.git-branches.toml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# More info around this file at https://www.git-town.com/configuration-file
|
||||||
|
|
||||||
|
[branches]
|
||||||
|
main = "main"
|
||||||
|
perennials = []
|
||||||
|
perennial-regex = ""
|
||||||
|
|
||||||
|
[create]
|
||||||
|
new-branch-type = "feature"
|
||||||
|
push-new-branches = false
|
||||||
|
|
||||||
|
[hosting]
|
||||||
|
dev-remote = "origin"
|
||||||
|
# platform = ""
|
||||||
|
# origin-hostname = ""
|
||||||
|
|
||||||
|
[ship]
|
||||||
|
delete-tracking-branch = false
|
||||||
|
strategy = "squash-merge"
|
||||||
|
|
||||||
|
[sync]
|
||||||
|
feature-strategy = "merge"
|
||||||
|
perennial-strategy = "rebase"
|
||||||
|
prototype-strategy = "merge"
|
||||||
|
push-hook = true
|
||||||
|
tags = true
|
||||||
|
upstream = false
|
||||||
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
15
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -37,17 +37,22 @@ If yes, which one?
|
|||||||
|
|
||||||
**Debug output**
|
**Debug output**
|
||||||
|
|
||||||
To help us resolve the problem, please attach the following debug output
|
To help us resolve the problem, please attach the following anonymized status output
|
||||||
|
|
||||||
netbird status -dA
|
netbird status -dA
|
||||||
|
|
||||||
As well as the file created by
|
Create and upload a debug bundle, and share the returned file key:
|
||||||
|
|
||||||
|
netbird debug for 1m -AS -U
|
||||||
|
|
||||||
|
*Uploaded files are automatically deleted after 30 days.*
|
||||||
|
|
||||||
|
|
||||||
|
Alternatively, create the file only and attach it here manually:
|
||||||
|
|
||||||
netbird debug for 1m -AS
|
netbird debug for 1m -AS
|
||||||
|
|
||||||
|
|
||||||
We advise reviewing the anonymized output for any remaining personal information.
|
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
If applicable, add screenshots to help explain your problem.
|
If applicable, add screenshots to help explain your problem.
|
||||||
@@ -57,8 +62,10 @@ If applicable, add screenshots to help explain your problem.
|
|||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|
||||||
**Have you tried these troubleshooting steps?**
|
**Have you tried these troubleshooting steps?**
|
||||||
|
- [ ] Reviewed [client troubleshooting](https://docs.netbird.io/how-to/troubleshooting-client) (if applicable)
|
||||||
- [ ] Checked for newer NetBird versions
|
- [ ] Checked for newer NetBird versions
|
||||||
- [ ] Searched for similar issues on GitHub (including closed ones)
|
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||||
- [ ] Restarted the NetBird client
|
- [ ] Restarted the NetBird client
|
||||||
- [ ] Disabled other VPN software
|
- [ ] Disabled other VPN software
|
||||||
- [ ] Checked firewall settings
|
- [ ] Checked firewall settings
|
||||||
|
|
||||||
|
|||||||
6
.github/pull_request_template.md
vendored
6
.github/pull_request_template.md
vendored
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
## Issue ticket number and link
|
## Issue ticket number and link
|
||||||
|
|
||||||
|
## Stack
|
||||||
|
|
||||||
|
<!-- branch-stack -->
|
||||||
|
|
||||||
### Checklist
|
### Checklist
|
||||||
- [ ] Is it a bug fix
|
- [ ] Is it a bug fix
|
||||||
- [ ] Is a typo/documentation fix
|
- [ ] Is a typo/documentation fix
|
||||||
@@ -9,3 +13,5 @@
|
|||||||
- [ ] It is a refactor
|
- [ ] It is a refactor
|
||||||
- [ ] Created tests that fail without the change (if possible)
|
- [ ] Created tests that fail without the change (if possible)
|
||||||
- [ ] Extended the README / documentation, if necessary
|
- [ ] Extended the README / documentation, if necessary
|
||||||
|
|
||||||
|
> By submitting this pull request, you confirm that you have read and agree to the terms of the [Contributor License Agreement](https://github.com/netbirdio/netbird/blob/main/CONTRIBUTOR_LICENSE_AGREEMENT.md).
|
||||||
|
|||||||
21
.github/workflows/git-town.yml
vendored
Normal file
21
.github/workflows/git-town.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: Git Town
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- '**'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
git-town:
|
||||||
|
name: Display the branch stack
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: git-town/action@v1
|
||||||
|
with:
|
||||||
|
skip-single-stacks: true
|
||||||
10
.github/workflows/golang-test-freebsd.yml
vendored
10
.github/workflows/golang-test-freebsd.yml
vendored
@@ -22,14 +22,20 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
release: "14.1"
|
release: "14.2"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y go pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
|
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
|
||||||
|
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
|
||||||
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
|
curl -vLO "$GO_URL"
|
||||||
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|
||||||
# -x - to print all executed commands
|
# -x - to print all executed commands
|
||||||
# -e - to faile on first error
|
# -e - to faile on first error
|
||||||
run: |
|
run: |
|
||||||
set -e -x
|
set -e -x
|
||||||
|
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -timeout 1m -failfast ./base62/...
|
||||||
|
|||||||
233
.github/workflows/golang-test-linux.yml
vendored
233
.github/workflows/golang-test-linux.yml
vendored
@@ -146,6 +146,65 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||||
|
|
||||||
|
test_client_on_docker:
|
||||||
|
name: "Client (Docker) / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
id: go-env
|
||||||
|
run: |
|
||||||
|
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
id: cache-restore
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ steps.go-env.outputs.cache_dir }}
|
||||||
|
${{ steps.go-env.outputs.modcache_dir }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Run tests in container
|
||||||
|
env:
|
||||||
|
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
|
||||||
|
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
|
||||||
|
run: |
|
||||||
|
CONTAINER_GOCACHE="/root/.cache/go-build"
|
||||||
|
CONTAINER_GOMODCACHE="/go/pkg/mod"
|
||||||
|
|
||||||
|
docker run --rm \
|
||||||
|
--cap-add=NET_ADMIN \
|
||||||
|
--privileged \
|
||||||
|
-v $PWD:/app \
|
||||||
|
-w /app \
|
||||||
|
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
|
||||||
|
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
|
||||||
|
-e CGO_ENABLED=1 \
|
||||||
|
-e CI=true \
|
||||||
|
-e DOCKER_CI=true \
|
||||||
|
-e GOARCH=${GOARCH_TARGET} \
|
||||||
|
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||||
|
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||||
|
golang:1.23-alpine \
|
||||||
|
sh -c ' \
|
||||||
|
apk update; apk add --no-cache \
|
||||||
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui -e /upload-server)
|
||||||
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
name: "Relay / Unit"
|
name: "Relay / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -164,6 +223,10 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -179,13 +242,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -217,6 +273,10 @@ jobs:
|
|||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo apt update && sudo apt install -y gcc-multilib g++-multilib libc6-dev-i386
|
||||||
|
|
||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
@@ -232,13 +292,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -286,13 +339,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -314,6 +360,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=devcert \
|
go test -tags=devcert \
|
||||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
@@ -353,13 +400,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -380,10 +420,11 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags devcert -run=^$ -bench=. \
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_benchmark:
|
api_benchmark:
|
||||||
name: "Management / Benchmark (API)"
|
name: "Management / Benchmark (API)"
|
||||||
@@ -396,6 +437,33 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Create Docker network
|
||||||
|
run: docker network create promnet
|
||||||
|
|
||||||
|
- name: Start Prometheus Pushgateway
|
||||||
|
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||||
|
|
||||||
|
- name: Start Prometheus (for Pushgateway forwarding)
|
||||||
|
run: |
|
||||||
|
echo '
|
||||||
|
global:
|
||||||
|
scrape_interval: 15s
|
||||||
|
scrape_configs:
|
||||||
|
- job_name: "pushgateway"
|
||||||
|
static_configs:
|
||||||
|
- targets: ["pushgateway:9091"]
|
||||||
|
remote_write:
|
||||||
|
- url: ${{ secrets.GRAFANA_URL }}
|
||||||
|
basic_auth:
|
||||||
|
username: ${{ secrets.GRAFANA_USER }}
|
||||||
|
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||||
|
' > prometheus.yml
|
||||||
|
|
||||||
|
docker run -d --name prometheus --network promnet \
|
||||||
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
|
-p 9090:9090 \
|
||||||
|
prom/prometheus
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
@@ -420,13 +488,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -447,11 +508,13 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
|
GIT_BRANCH=${{ github.ref_name }} \
|
||||||
go test -tags=benchmark \
|
go test -tags=benchmark \
|
||||||
-run=^$ \
|
-run=^$ \
|
||||||
-bench=. \
|
-bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
@@ -489,13 +552,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -505,89 +561,8 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=integration \
|
go test -tags=integration \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
test_client_on_docker:
|
|
||||||
name: "Client (Docker) / Unit"
|
|
||||||
needs: [ build-cache ]
|
|
||||||
runs-on: ubuntu-20.04
|
|
||||||
steps:
|
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
|
||||||
run: |
|
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache/restore@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
${{ env.cache }}
|
|
||||||
${{ env.modcache }}
|
|
||||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-gotest-cache-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install modules
|
|
||||||
run: go mod tidy
|
|
||||||
|
|
||||||
- name: check git status
|
|
||||||
run: git --no-pager diff --exit-code
|
|
||||||
|
|
||||||
- name: Generate Shared Sock Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
|
||||||
|
|
||||||
- name: Generate RouteManager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
|
||||||
|
|
||||||
- name: Generate SystemOps Test bin
|
|
||||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
|
||||||
|
|
||||||
- name: Generate Engine Test bin
|
|
||||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
|
||||||
|
|
||||||
- name: Generate Peer Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
|
||||||
|
|
||||||
- run: chmod +x *testing.bin
|
|
||||||
|
|
||||||
- name: Run Shared Sock tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Iface tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
|
||||||
|
|
||||||
- name: Run RouteManager tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run SystemOps tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run nftables Manager tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker with file store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker with sqlite store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|||||||
1
.github/workflows/golangci-lint.yml
vendored
1
.github/workflows/golangci-lint.yml
vendored
@@ -21,7 +21,6 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
|
||||||
golangci:
|
golangci:
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
|
|||||||
@@ -172,12 +172,14 @@ jobs:
|
|||||||
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||||
# check relay values
|
# check relay values
|
||||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
grep "NB_EXPOSED_ADDRESS=rels://$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||||
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
grep "NB_LISTEN_ADDRESS=:33445" docker-compose.yml
|
||||||
grep '33445:33445' docker-compose.yml
|
grep '33445:33445' docker-compose.yml
|
||||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
grep -A 7 Relay management.json | grep "rels://$CI_NETBIRD_DOMAIN:33445"
|
||||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||||
|
grep DisablePromptLogin management.json | grep 'true'
|
||||||
|
grep LoginFlag management.json | grep 0
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|||||||
@@ -96,6 +96,20 @@ builds:
|
|||||||
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
mod_timestamp: "{{ .CommitTimestamp }}"
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
|
- id: netbird-upload
|
||||||
|
dir: upload-server
|
||||||
|
env: [CGO_ENABLED=0]
|
||||||
|
binary: netbird-upload
|
||||||
|
goos:
|
||||||
|
- linux
|
||||||
|
goarch:
|
||||||
|
- amd64
|
||||||
|
- arm64
|
||||||
|
- arm
|
||||||
|
ldflags:
|
||||||
|
- -s -w -X github.com/netbirdio/netbird/version.version={{.Version}} -X main.commit={{.Commit}} -X main.date={{.CommitDate}} -X main.builtBy=goreleaser
|
||||||
|
mod_timestamp: "{{ .CommitTimestamp }}"
|
||||||
|
|
||||||
universal_binaries:
|
universal_binaries:
|
||||||
- id: netbird
|
- id: netbird
|
||||||
|
|
||||||
@@ -409,6 +423,52 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
ids:
|
||||||
|
- netbird-upload
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: upload-server/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird-upload
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: upload-server/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
|
ids:
|
||||||
|
- netbird-upload
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: upload-server/Dockerfile
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
docker_manifests:
|
docker_manifests:
|
||||||
- name_template: netbirdio/netbird:{{ .Version }}
|
- name_template: netbirdio/netbird:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
@@ -475,7 +535,17 @@ docker_manifests:
|
|||||||
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
- netbirdio/management:{{ .Version }}-debug-arm64v8
|
||||||
- netbirdio/management:{{ .Version }}-debug-arm
|
- netbirdio/management:{{ .Version }}-debug-arm
|
||||||
- netbirdio/management:{{ .Version }}-debug-amd64
|
- netbirdio/management:{{ .Version }}-debug-amd64
|
||||||
|
- name_template: netbirdio/upload:{{ .Version }}
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/upload:latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm64v8
|
||||||
|
- netbirdio/upload:{{ .Version }}-arm
|
||||||
|
- netbirdio/upload:{{ .Version }}-amd64
|
||||||
brews:
|
brews:
|
||||||
- ids:
|
- ids:
|
||||||
- default
|
- default
|
||||||
|
|||||||
20
README.md
20
README.md
@@ -12,7 +12,7 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
<a href="https://docs.netbird.io/slack-url">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
@@ -29,7 +29,7 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
Join our <a href="https://docs.netbird.io/slack-url">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
@@ -58,15 +58,15 @@
|
|||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
| Connectivity | Management | Security | Automation| Platforms |
|
| Connectivity | Management | Security | Automation| Platforms |
|
||||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
|----|----|----|----|----|
|
||||||
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||||
|
|
||||||
### Quickstart with NetBird Cloud
|
### Quickstart with NetBird Cloud
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
FROM alpine:3.21.3
|
FROM alpine:3.21.3
|
||||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
|
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||||
COPY netbird /usr/local/bin/netbird
|
COPY netbird /usr/local/bin/netbird
|
||||||
@@ -69,6 +69,22 @@ func (a *Anonymizer) AnonymizeIP(ip netip.Addr) netip.Addr {
|
|||||||
return a.ipAnonymizer[ip]
|
return a.ipAnonymizer[ip]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (a *Anonymizer) AnonymizeUDPAddr(addr net.UDPAddr) net.UDPAddr {
|
||||||
|
// Convert IP to netip.Addr
|
||||||
|
ip, ok := netip.AddrFromSlice(addr.IP)
|
||||||
|
if !ok {
|
||||||
|
return addr
|
||||||
|
}
|
||||||
|
|
||||||
|
anonIP := a.AnonymizeIP(ip)
|
||||||
|
|
||||||
|
return net.UDPAddr{
|
||||||
|
IP: anonIP.AsSlice(),
|
||||||
|
Port: addr.Port,
|
||||||
|
Zone: addr.Zone,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
// isInAnonymizedRange checks if an IP is within the range of already assigned anonymized IPs
|
||||||
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
func (a *Anonymizer) isInAnonymizedRange(ip netip.Addr) bool {
|
||||||
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
if ip.Is4() && ip.Compare(a.startAnonIPv4) >= 0 && ip.Compare(a.currentAnonIPv4) <= 0 {
|
||||||
|
|||||||
@@ -11,9 +11,12 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@@ -84,16 +87,27 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
request := &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: getStatusOutput(cmd, anonymizeFlag),
|
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||||
SystemInfo: debugSystemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
})
|
}
|
||||||
|
if debugUploadBundle {
|
||||||
|
request.UploadURL = debugUploadBundleURL
|
||||||
|
}
|
||||||
|
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||||
|
|
||||||
cmd.Println(resp.GetPath())
|
if resp.GetUploadFailureReason() != "" {
|
||||||
|
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||||
|
}
|
||||||
|
|
||||||
|
if debugUploadBundle {
|
||||||
|
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -208,23 +222,19 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
request := &proto.DebugBundleRequest{
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: statusOutput,
|
Status: statusOutput,
|
||||||
SystemInfo: debugSystemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
})
|
}
|
||||||
|
if debugUploadBundle {
|
||||||
|
request.UploadURL = debugUploadBundleURL
|
||||||
|
}
|
||||||
|
resp, err := client.DebugBundle(cmd.Context(), request)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
|
||||||
// Disable network map persistence after creating the debug bundle
|
|
||||||
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
|
||||||
Enabled: false,
|
|
||||||
}); err != nil {
|
|
||||||
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||||
@@ -239,7 +249,15 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
cmd.Println("Log level restored to", initialLogLevel.GetLevel())
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println(resp.GetPath())
|
cmd.Printf("Local file:\n%s\n", resp.GetPath())
|
||||||
|
|
||||||
|
if resp.GetUploadFailureReason() != "" {
|
||||||
|
return fmt.Errorf("upload failed: %s", resp.GetUploadFailureReason())
|
||||||
|
}
|
||||||
|
|
||||||
|
if debugUploadBundle {
|
||||||
|
cmd.Printf("Upload file key:\n%s\n", resp.GetUploadedKey())
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -326,3 +344,34 @@ func formatDuration(d time.Duration) string {
|
|||||||
s := d / time.Second
|
s := d / time.Second
|
||||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||||
|
var networkMap *mgmProto.NetworkMap
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if connectClient != nil {
|
||||||
|
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get latest network map: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
|
debug.GeneratorDependencies{
|
||||||
|
InternalConfig: config,
|
||||||
|
StatusRecorder: recorder,
|
||||||
|
NetworkMap: networkMap,
|
||||||
|
LogFile: logFilePath,
|
||||||
|
},
|
||||||
|
debug.BundleConfig{
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||||
|
}
|
||||||
|
|||||||
39
client/cmd/debug_unix.go
Normal file
39
client/cmd/debug_unix.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
usr1Ch := make(chan os.Signal, 1)
|
||||||
|
|
||||||
|
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-usr1Ch:
|
||||||
|
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
126
client/cmd/debug_windows.go
Normal file
126
client/cmd/debug_windows.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||||
|
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||||
|
|
||||||
|
waitTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||||
|
// Example usage with PowerShell:
|
||||||
|
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||||
|
// $evt.Set()
|
||||||
|
// $evt.Close()
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
env := os.Getenv(envListenEvent)
|
||||||
|
if env == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listenEvent, err := strconv.ParseBool(env)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !listenEvent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: restrict access by ACL
|
||||||
|
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||||
|
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||||
|
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||||
|
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||||
|
} else {
|
||||||
|
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventHandle == windows.InvalidHandle {
|
||||||
|
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||||
|
|
||||||
|
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
eventHandle windows.Handle,
|
||||||
|
) {
|
||||||
|
defer func() {
|
||||||
|
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case windows.WAIT_OBJECT_0:
|
||||||
|
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||||
|
|
||||||
|
// reset the event so it can be triggered again later (manual reset == 1)
|
||||||
|
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
case uint32(windows.WAIT_TIMEOUT):
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"os"
|
"os"
|
||||||
|
"runtime"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,6 +20,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
}
|
||||||
|
|
||||||
var loginCmd = &cobra.Command{
|
var loginCmd = &cobra.Command{
|
||||||
Use: "login",
|
Use: "login",
|
||||||
Short: "login to the Netbird Management Service (first run)",
|
Short: "login to the Netbird Management Service (first run)",
|
||||||
@@ -51,6 +56,9 @@ var loginCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update host's static platform and system information
|
||||||
|
system.UpdateStaticInfo()
|
||||||
|
|
||||||
ic := internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
@@ -93,7 +101,7 @@ var loginCmd = &cobra.Command{
|
|||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
DnsLabels: dnsLabelsReq,
|
DnsLabels: dnsLabelsReq,
|
||||||
}
|
}
|
||||||
@@ -127,7 +135,7 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -188,7 +196,7 @@ func foregroundLogin(ctx context.Context, cmd *cobra.Command, config *internal.C
|
|||||||
}
|
}
|
||||||
|
|
||||||
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *internal.Config) (*auth.TokenInfo, error) {
|
||||||
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isLinuxRunningDesktop())
|
oAuthFlow, err := auth.NewOAuthFlow(ctx, config, isUnixRunningDesktop())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -198,7 +206,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||||
@@ -212,23 +220,34 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if noBrowser {
|
||||||
|
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
|
||||||
|
} else {
|
||||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
verificationURIComplete + " " + codeMsg)
|
verificationURIComplete + " " + codeMsg)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
|
if !noBrowser {
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
if err := open.Run(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
// isUnixRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
func isLinuxRunningDesktop() bool {
|
func isUnixRunningDesktop() bool {
|
||||||
|
if runtime.GOOS != "linux" && runtime.GOOS != "freebsd" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
return os.Getenv("DESKTOP_SESSION") != "" || os.Getenv("XDG_CURRENT_DESKTOP") != ""
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,6 +22,7 @@ import (
|
|||||||
"google.golang.org/grpc/credentials/insecure"
|
"google.golang.org/grpc/credentials/insecure"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/upload-server/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -38,7 +39,9 @@ const (
|
|||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
systemInfoFlag = "system-info"
|
systemInfoFlag = "system-info"
|
||||||
blockLANAccessFlag = "block-lan-access"
|
enableLazyConnectionFlag = "enable-lazy-connection"
|
||||||
|
uploadBundle = "upload-bundle"
|
||||||
|
uploadBundleURL = "upload-bundle-url"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -74,7 +77,9 @@ var (
|
|||||||
anonymizeFlag bool
|
anonymizeFlag bool
|
||||||
debugSystemInfoFlag bool
|
debugSystemInfoFlag bool
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
blockLANAccess bool
|
debugUploadBundle bool
|
||||||
|
debugUploadBundleURL string
|
||||||
|
lazyConnEnabled bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@@ -179,8 +184,11 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
upCmd.PersistentFlags().BoolVar(&rosenpassPermissive, rosenpassPermissiveFlag, false, "[Experimental] Enable Rosenpass in permissive mode to allow this peer to accept WireGuard connections without requiring Rosenpass functionality from peers that do not have Rosenpass enabled.")
|
||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&lazyConnEnabled, enableLazyConnectionFlag, false, "[Experimental] Enable the lazy connection feature. If enabled, the client will establish connections on-demand.")
|
||||||
|
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||||
|
debugCmd.PersistentFlags().BoolVarP(&debugUploadBundle, uploadBundle, "U", false, fmt.Sprintf("Uploads the debug bundle to a server from URL defined by %s", uploadBundleURL))
|
||||||
|
debugCmd.PersistentFlags().StringVar(&debugUploadBundleURL, uploadBundleURL, types.DefaultBundleURL, "Service URL to get an URL to upload the debug bundle")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"runtime"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/kardianos/service"
|
"github.com/kardianos/service"
|
||||||
@@ -27,12 +28,19 @@ func newProgram(ctx context.Context, cancel context.CancelFunc) *program {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func newSVCConfig() *service.Config {
|
func newSVCConfig() *service.Config {
|
||||||
return &service.Config{
|
config := &service.Config{
|
||||||
Name: serviceName,
|
Name: serviceName,
|
||||||
DisplayName: "Netbird",
|
DisplayName: "Netbird",
|
||||||
Description: "A WireGuard-based mesh network that connects your devices into a single private network.",
|
Description: "Netbird mesh network client",
|
||||||
Option: make(service.KeyValue),
|
Option: make(service.KeyValue),
|
||||||
|
EnvVars: make(map[string]string),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if runtime.GOOS == "linux" {
|
||||||
|
config.EnvVars["SYSTEMD_UNIT"] = serviceName
|
||||||
|
}
|
||||||
|
|
||||||
|
return config
|
||||||
}
|
}
|
||||||
|
|
||||||
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
func newSVC(prg *program, conf *service.Config) (service.Service, error) {
|
||||||
|
|||||||
@@ -16,12 +16,17 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *program) Start(svc service.Service) error {
|
func (p *program) Start(svc service.Service) error {
|
||||||
// Start should not block. Do the actual work async.
|
// Start should not block. Do the actual work async.
|
||||||
log.Info("starting Netbird service") //nolint
|
log.Info("starting Netbird service") //nolint
|
||||||
|
|
||||||
|
// Collect static system and platform information
|
||||||
|
system.UpdateStaticInfo()
|
||||||
|
|
||||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||||
p.serv = grpc.NewServer()
|
p.serv = grpc.NewServer()
|
||||||
|
|
||||||
@@ -115,6 +120,7 @@ var runCmd = &cobra.Command{
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
|
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -39,7 +39,7 @@ var installCmd = &cobra.Command{
|
|||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--management-url", managementURL)
|
||||||
}
|
}
|
||||||
|
|
||||||
if logFile != "console" {
|
if logFile != "" {
|
||||||
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
svcConfig.Arguments = append(svcConfig.Arguments, "--log-file", logFile)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -44,7 +44,7 @@ func init() {
|
|||||||
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
statusCmd.MarkFlagsMutuallyExclusive("detail", "json", "yaml", "ipv4")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
statusCmd.PersistentFlags().StringSliceVar(&ipsFilter, "filter-by-ips", []string{}, "filters the detailed output by a list of one or more IPs, e.g., --filter-by-ips 100.64.0.100,100.64.0.200")
|
||||||
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
statusCmd.PersistentFlags().StringSliceVar(&prefixNamesFilter, "filter-by-names", []string{}, "filters the detailed output by a list of one or more peer FQDN or hostnames, e.g., --filter-by-names peer-a,peer-b.netbird.cloud")
|
||||||
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(connected|disconnected), e.g., --filter-by-status connected")
|
statusCmd.PersistentFlags().StringVar(&statusFilter, "filter-by-status", "", "filters the detailed output by connection status(idle|connecting|connected), e.g., --filter-by-status connected")
|
||||||
}
|
}
|
||||||
|
|
||||||
func statusFunc(cmd *cobra.Command, args []string) error {
|
func statusFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -127,12 +127,12 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "idle", "connecting", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
if strings.ToLower(statusFilter) != "" {
|
||||||
enableDetailFlagWhenFilterFlag()
|
enableDetailFlagWhenFilterFlag()
|
||||||
}
|
}
|
||||||
default:
|
default:
|
||||||
return fmt.Errorf("wrong status filter, should be one of connected|disconnected, got: %s", statusFilter)
|
return fmt.Errorf("wrong status filter, should be one of connected|connecting|idle, got: %s", statusFilter)
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(ipsFilter) > 0 {
|
if len(ipsFilter) > 0 {
|
||||||
|
|||||||
@@ -6,6 +6,8 @@ const (
|
|||||||
disableServerRoutesFlag = "disable-server-routes"
|
disableServerRoutesFlag = "disable-server-routes"
|
||||||
disableDNSFlag = "disable-dns"
|
disableDNSFlag = "disable-dns"
|
||||||
disableFirewallFlag = "disable-firewall"
|
disableFirewallFlag = "disable-firewall"
|
||||||
|
blockLANAccessFlag = "block-lan-access"
|
||||||
|
blockInboundFlag = "block-inbound"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -13,6 +15,8 @@ var (
|
|||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
disableDNS bool
|
disableDNS bool
|
||||||
disableFirewall bool
|
disableFirewall bool
|
||||||
|
blockLANAccess bool
|
||||||
|
blockInbound bool
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
func init() {
|
||||||
@@ -28,4 +32,11 @@ func init() {
|
|||||||
|
|
||||||
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||||
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false,
|
||||||
|
"Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockInbound, blockInboundFlag, false,
|
||||||
|
"Block inbound connections. If enabled, the client will not allow any inbound connections to the local machine nor routed networks.\n"+
|
||||||
|
"This overrides any policies received from the management service.")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -91,13 +92,18 @@ func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc
|
|||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ctrl := gomock.NewController(t)
|
ctrl := gomock.NewController(t)
|
||||||
t.Cleanup(ctrl.Finish)
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
settingsMockManager := settings.NewMockManager(ctrl)
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager)
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -17,7 +17,7 @@ var traceCmd = &cobra.Command{
|
|||||||
Example: `
|
Example: `
|
||||||
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --icmp-type 8 --icmp-code 0
|
||||||
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
Args: cobra.ExactArgs(3),
|
Args: cobra.ExactArgs(3),
|
||||||
RunE: tracePacket,
|
RunE: tracePacket,
|
||||||
|
|||||||
243
client/cmd/up.go
243
client/cmd/up.go
@@ -32,12 +32,16 @@ const (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
dnsLabelsFlag = "extra-dns-labels"
|
dnsLabelsFlag = "extra-dns-labels"
|
||||||
|
|
||||||
|
noBrowserFlag = "no-browser"
|
||||||
|
noBrowserDesc = "do not open the browser for SSO login"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
foregroundMode bool
|
foregroundMode bool
|
||||||
dnsLabels []string
|
dnsLabels []string
|
||||||
dnsLabelsValidated domain.List
|
dnsLabelsValidated domain.List
|
||||||
|
noBrowser bool
|
||||||
|
|
||||||
upCmd = &cobra.Command{
|
upCmd = &cobra.Command{
|
||||||
Use: "up",
|
Use: "up",
|
||||||
@@ -51,12 +55,11 @@ func init() {
|
|||||||
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
upCmd.PersistentFlags().StringVar(&interfaceName, interfaceNameFlag, iface.WgInterfaceDefault, "Wireguard interface name")
|
||||||
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
upCmd.PersistentFlags().Uint16Var(&wireguardPort, wireguardPortFlag, iface.DefaultWgPort, "Wireguard interface listening port")
|
||||||
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
upCmd.PersistentFlags().BoolVarP(&networkMonitor, networkMonitorFlag, "N", networkMonitor,
|
||||||
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux. `+
|
`Manage network monitoring. Defaults to true on Windows and macOS, false on Linux and FreeBSD. `+
|
||||||
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
`E.g. --network-monitor=false to disable or --network-monitor=true to enable.`,
|
||||||
)
|
)
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||||
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
|
||||||
|
|
||||||
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||||
`Sets DNS labels`+
|
`Sets DNS labels`+
|
||||||
@@ -65,6 +68,9 @@ func init() {
|
|||||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||||
`or --extra-dns-labels ""`,
|
`or --extra-dns-labels ""`,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -112,6 +118,124 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
ic, err := setupConfig(customDNSAddressConverted, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setup config: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
providedSetupKey, err := getSetupKey()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
config, err := internal.UpdateOrCreateConfig(*ic)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get config file: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
||||||
|
|
||||||
|
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("foreground login failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var cancel context.CancelFunc
|
||||||
|
ctx, cancel = context.WithCancel(ctx)
|
||||||
|
SetupCloseHandler(ctx, cancel)
|
||||||
|
|
||||||
|
r := peer.NewRecorder(config.ManagementURL.String())
|
||||||
|
r.GetFullStatus()
|
||||||
|
|
||||||
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
|
return connectClient.Run(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
|
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
||||||
|
"If the daemon is not running please run: "+
|
||||||
|
"\nnetbird service install \nnetbird service start\n", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err := conn.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
|
status, err := client.Status(ctx, &proto.StatusRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("unable to get daemon status: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if status.Status == string(internal.StatusConnected) {
|
||||||
|
cmd.Println("Already connected")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
providedSetupKey, err := getSetupKey()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get setup key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
loginRequest, err := setupLoginRequest(providedSetupKey, customDNSAddressConverted, cmd)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("setup login request: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var loginErr error
|
||||||
|
var loginResp *proto.LoginResponse
|
||||||
|
|
||||||
|
err = WithBackOff(func() error {
|
||||||
|
var backOffErr error
|
||||||
|
loginResp, backOffErr = client.Login(ctx, loginRequest)
|
||||||
|
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
||||||
|
s.Code() == codes.PermissionDenied ||
|
||||||
|
s.Code() == codes.NotFound ||
|
||||||
|
s.Code() == codes.Unimplemented) {
|
||||||
|
loginErr = backOffErr
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return backOffErr
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("login backoff cycle failed: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginErr != nil {
|
||||||
|
return fmt.Errorf("login failed: %v", loginErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("waiting sso login failed with: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
||||||
|
return fmt.Errorf("call service up method: %v", err)
|
||||||
|
}
|
||||||
|
cmd.Println("Connected")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupConfig(customDNSAddressConverted []byte, cmd *cobra.Command) (*internal.ConfigInput, error) {
|
||||||
ic := internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
@@ -136,7 +260,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
ic.InterfaceName = &interfaceName
|
ic.InterfaceName = &interfaceName
|
||||||
}
|
}
|
||||||
@@ -187,71 +311,17 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.BlockLANAccess = &blockLANAccess
|
ic.BlockLANAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
if cmd.Flag(blockInboundFlag).Changed {
|
||||||
if err != nil {
|
ic.BlockInbound = &blockInbound
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
config, err := internal.UpdateOrCreateConfig(ic)
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
if err != nil {
|
ic.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
return fmt.Errorf("get config file: %v", err)
|
}
|
||||||
}
|
return &ic, nil
|
||||||
|
|
||||||
config, _ = internal.UpdateOldManagementURL(ctx, config, configPath)
|
|
||||||
|
|
||||||
err = foregroundLogin(ctx, cmd, config, providedSetupKey)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("foreground login failed: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
var cancel context.CancelFunc
|
|
||||||
ctx, cancel = context.WithCancel(ctx)
|
|
||||||
SetupCloseHandler(ctx, cancel)
|
|
||||||
|
|
||||||
r := peer.NewRecorder(config.ManagementURL.String())
|
|
||||||
r.GetFullStatus()
|
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
|
||||||
return connectClient.Run(nil)
|
|
||||||
}
|
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|
||||||
customDNSAddressConverted, err := parseCustomDNSAddress(cmd.Flag(dnsResolverAddress).Changed)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
conn, err := DialClientGRPCServer(ctx, daemonAddr)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to connect to daemon error: %v\n"+
|
|
||||||
"If the daemon is not running please run: "+
|
|
||||||
"\nnetbird service install \nnetbird service start\n", err)
|
|
||||||
}
|
|
||||||
defer func() {
|
|
||||||
err := conn.Close()
|
|
||||||
if err != nil {
|
|
||||||
log.Warnf("failed closing daemon gRPC client connection %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
|
|
||||||
status, err := client.Status(ctx, &proto.StatusRequest{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("unable to get daemon status: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if status.Status == string(internal.StatusConnected) {
|
|
||||||
cmd.Println("Already connected")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func setupLoginRequest(providedSetupKey string, customDNSAddressConverted []byte, cmd *cobra.Command) (*proto.LoginRequest, error) {
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
@@ -259,7 +329,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
NatExternalIPs: natExternalIPs,
|
NatExternalIPs: natExternalIPs,
|
||||||
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
CleanNATExternalIPs: natExternalIPs != nil && len(natExternalIPs) == 0,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsUnixDesktopClient: isUnixRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
DnsLabels: dnsLabels,
|
DnsLabels: dnsLabels,
|
||||||
@@ -288,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if cmd.Flag(interfaceNameFlag).Changed {
|
if cmd.Flag(interfaceNameFlag).Changed {
|
||||||
if err := parseInterfaceName(interfaceName); err != nil {
|
if err := parseInterfaceName(interfaceName); err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
loginRequest.InterfaceName = &interfaceName
|
loginRequest.InterfaceName = &interfaceName
|
||||||
}
|
}
|
||||||
@@ -323,45 +393,14 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.BlockLanAccess = &blockLANAccess
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
}
|
}
|
||||||
|
|
||||||
var loginErr error
|
if cmd.Flag(blockInboundFlag).Changed {
|
||||||
|
loginRequest.BlockInbound = &blockInbound
|
||||||
var loginResp *proto.LoginResponse
|
|
||||||
|
|
||||||
err = WithBackOff(func() error {
|
|
||||||
var backOffErr error
|
|
||||||
loginResp, backOffErr = client.Login(ctx, &loginRequest)
|
|
||||||
if s, ok := gstatus.FromError(backOffErr); ok && (s.Code() == codes.InvalidArgument ||
|
|
||||||
s.Code() == codes.PermissionDenied ||
|
|
||||||
s.Code() == codes.NotFound ||
|
|
||||||
s.Code() == codes.Unimplemented) {
|
|
||||||
loginErr = backOffErr
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
return backOffErr
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("login backoff cycle failed: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if loginErr != nil {
|
if cmd.Flag(enableLazyConnectionFlag).Changed {
|
||||||
return fmt.Errorf("login failed: %v", loginErr)
|
loginRequest.LazyConnectionEnabled = &lazyConnEnabled
|
||||||
}
|
}
|
||||||
|
return &loginRequest, nil
|
||||||
if loginResp.NeedsSSOLogin {
|
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("waiting sso login failed with: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.Up(ctx, &proto.UpRequest{}); err != nil {
|
|
||||||
return fmt.Errorf("call service up method: %v", err)
|
|
||||||
}
|
|
||||||
cmd.Println("Connected")
|
|
||||||
return nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateNATExternalIPs(list []string) error {
|
func validateNATExternalIPs(list []string) error {
|
||||||
|
|||||||
@@ -113,17 +113,16 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if !destination.Addr().Is4() {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
@@ -148,6 +147,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -199,7 +202,7 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
firewall.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
@@ -220,10 +223,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -243,6 +252,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package iptables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -19,11 +19,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,12 +67,12 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -95,9 +92,9 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -119,11 +116,8 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -144,11 +138,11 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := netip.MustParseAddr("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
rule2, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -186,11 +180,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: netip.MustParseAddr("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("10.20.0.0/24"),
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -212,11 +203,11 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
|
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,6 +41,8 @@ const (
|
|||||||
jumpManglePre = "jump-mangle-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNatPre = "jump-nat-pre"
|
jumpNatPre = "jump-nat-pre"
|
||||||
jumpNatPost = "jump-nat-post"
|
jumpNatPost = "jump-nat-post"
|
||||||
|
markManglePre = "mark-mangle-pre"
|
||||||
|
markManglePost = "mark-mangle-post"
|
||||||
matchSet = "--match-set"
|
matchSet = "--match-set"
|
||||||
|
|
||||||
dnatSuffix = "_dnat"
|
dnatSuffix = "_dnat"
|
||||||
@@ -55,18 +57,18 @@ type ruleInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type routeFilteringRuleParams struct {
|
type routeFilteringRuleParams struct {
|
||||||
Sources []netip.Prefix
|
Source firewall.Network
|
||||||
Destination netip.Prefix
|
Destination firewall.Network
|
||||||
Proto firewall.Protocol
|
Proto firewall.Protocol
|
||||||
SPort *firewall.Port
|
SPort *firewall.Port
|
||||||
DPort *firewall.Port
|
DPort *firewall.Port
|
||||||
Direction firewall.RuleDirection
|
Direction firewall.RuleDirection
|
||||||
Action firewall.Action
|
Action firewall.Action
|
||||||
SetName string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type routeRules map[string][]string
|
type routeRules map[string][]string
|
||||||
|
|
||||||
|
// the ipset library currently does not support comments, so we use the name only (string)
|
||||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
@@ -115,6 +117,10 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
return fmt.Errorf("create containers: %w", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
r.updateState()
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -123,7 +129,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
@@ -134,27 +140,28 @@ func (r *router) AddRouteFiltering(
|
|||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var setName string
|
var source firewall.Network
|
||||||
if len(sources) > 1 {
|
if len(sources) > 1 {
|
||||||
setName = firewall.GenerateSetName(sources)
|
source.Set = firewall.NewPrefixSet(sources)
|
||||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
} else if len(sources) > 0 {
|
||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
source.Prefix = sources[0]
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
params := routeFilteringRuleParams{
|
params := routeFilteringRuleParams{
|
||||||
Sources: sources,
|
Source: source,
|
||||||
Destination: destination,
|
Destination: destination,
|
||||||
Proto: proto,
|
Proto: proto,
|
||||||
SPort: sPort,
|
SPort: sPort,
|
||||||
DPort: dPort,
|
DPort: dPort,
|
||||||
Action: action,
|
Action: action,
|
||||||
SetName: setName,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := genRouteFilteringRuleSpec(params)
|
rule, err := r.genRouteRuleSpec(params, sources)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate route rule spec: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
var err error
|
|
||||||
if action == firewall.ActionDrop {
|
if action == firewall.ActionDrop {
|
||||||
// after the established rule
|
// after the established rule
|
||||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||||
@@ -177,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
setName := r.findSetNameInRule(rule)
|
|
||||||
|
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("delete route rule: %v", err)
|
return fmt.Errorf("delete route rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
if setName != "" {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
@@ -198,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) findSetNameInRule(rule []string) string {
|
func (r *router) decrementSetCounter(rule []string) error {
|
||||||
|
sets := r.findSets(rule)
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, setName := range sets {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSets(rule []string) []string {
|
||||||
|
var sets []string
|
||||||
for i, arg := range rule {
|
for i, arg := range rule {
|
||||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||||
return rule[i+3]
|
sets = append(sets, rule[i+3])
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return sets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
@@ -225,15 +241,13 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
if err := ipset.Destroy(setName); err != nil {
|
if err := ipset.Destroy(setName); err != nil {
|
||||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Deleted unused ipset %s", setName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if r.legacyManagement {
|
if r.legacyManagement {
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
@@ -260,10 +274,7 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
if pair.Masquerade {
|
||||||
log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -271,6 +282,7 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
@@ -307,8 +319,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
|
||||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -348,12 +362,16 @@ func (r *router) Reset() error {
|
|||||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
r.rules = make(map[string][]string)
|
|
||||||
|
|
||||||
if err := r.ipsetCounter.Flush(); err != nil {
|
if err := r.ipsetCounter.Flush(); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules = make(map[string][]string)
|
||||||
r.updateState()
|
r.updateState()
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -423,6 +441,57 @@ func (r *router) createContainers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupDataPlaneMark configures the fwmark for the data plane
|
||||||
|
func (r *router) setupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
preRule := []string{
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePre] = preRule
|
||||||
|
}
|
||||||
|
|
||||||
|
postRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePost] = postRule
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) cleanupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
if preRule, exists := r.rules[markManglePre]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePre)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if postRule, exists := r.rules[markManglePost]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) addPostroutingRules() error {
|
func (r *router) addPostroutingRules() error {
|
||||||
// First rule for outbound masquerade
|
// First rule for outbound masquerade
|
||||||
rule1 := []string{
|
rule1 := []string{
|
||||||
@@ -464,7 +533,7 @@ func (r *router) insertEstablishedRule(chain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) addJumpRules() error {
|
func (r *router) addJumpRules() error {
|
||||||
// Jump to NAT chain
|
// Jump to nat chain
|
||||||
natRule := []string{"-j", chainRTNAT}
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||||
@@ -538,12 +607,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
rule = append(rule,
|
rule = append(rule,
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
"--ctstate", "NEW",
|
"--ctstate", "NEW",
|
||||||
"-s", pair.Source.String(),
|
)
|
||||||
"-d", pair.Destination.String(),
|
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -s: %w", err)
|
||||||
|
}
|
||||||
|
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -d: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule, sourceExp...)
|
||||||
|
rule = append(rule, destExp...)
|
||||||
|
rule = append(rule,
|
||||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
||||||
|
// TODO: rollback ipset counter
|
||||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -561,6 +644,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("marking rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
@@ -726,17 +813,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
|
||||||
var rule []string
|
var rule []string
|
||||||
|
|
||||||
if params.SetName != "" {
|
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
|
||||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
if err != nil {
|
||||||
} else if len(params.Sources) > 0 {
|
return nil, fmt.Errorf("apply network -s: %w", err)
|
||||||
source := params.Sources[0]
|
|
||||||
rule = append(rule, "-s", source.String())
|
}
|
||||||
|
destExp, err := r.applyNetwork("-d", params.Destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rule = append(rule, "-d", params.Destination.String())
|
rule = append(rule, sourceExp...)
|
||||||
|
rule = append(rule, destExp...)
|
||||||
|
|
||||||
if params.Proto != firewall.ProtocolALL {
|
if params.Proto != firewall.ProtocolALL {
|
||||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||||
@@ -746,7 +837,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
|||||||
|
|
||||||
rule = append(rule, "-j", actionToStr(params.Action))
|
rule = append(rule, "-j", actionToStr(params.Action))
|
||||||
|
|
||||||
return rule
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||||
|
direction := "src"
|
||||||
|
if flag == "-d" {
|
||||||
|
direction = "dst"
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsSet() {
|
||||||
|
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||||
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||||
|
}
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return []string{flag, network.Prefix.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:nilnil
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
// TODO: Implement IPv6 support
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr == nil {
|
||||||
|
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyPort(flag string, port *firewall.Port) []string {
|
func applyPort(flag string, port *firewall.Port) []string {
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
// 5. jump rule to PRE nat chain
|
// 5. jump rule to PRE nat chain
|
||||||
// 6. static outbound masquerade rule
|
// 6. static outbound masquerade rule
|
||||||
// 7. static return masquerade rule
|
// 7. static return masquerade rule
|
||||||
require.Len(t, manager.rules, 7, "should have created rules map")
|
// 8. mangle prerouting mark rule
|
||||||
|
// 9. mangle postrouting mark rule
|
||||||
|
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
@@ -58,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
@@ -345,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
|
var source firewall.Network
|
||||||
|
if len(tt.sources) > 1 {
|
||||||
|
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||||
|
} else if len(tt.sources) > 0 {
|
||||||
|
source.Prefix = tt.sources[0]
|
||||||
|
}
|
||||||
// Verify rule content
|
// Verify rule content
|
||||||
params := routeFilteringRuleParams{
|
params := routeFilteringRuleParams{
|
||||||
Sources: tt.sources,
|
Source: source,
|
||||||
Destination: tt.destination,
|
Destination: firewall.Network{Prefix: tt.destination},
|
||||||
Proto: tt.proto,
|
Proto: tt.proto,
|
||||||
SPort: tt.sPort,
|
SPort: tt.sPort,
|
||||||
DPort: tt.dPort,
|
DPort: tt.dPort,
|
||||||
Action: tt.action,
|
Action: tt.action,
|
||||||
SetName: "",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedRule := genRouteFilteringRuleSpec(params)
|
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||||
|
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||||
|
|
||||||
if tt.expectSet {
|
if tt.expectSet {
|
||||||
setName := firewall.GenerateSetName(tt.sources)
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
params.SetName = setName
|
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||||
expectedRule = genRouteFilteringRuleSpec(params)
|
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||||
|
|
||||||
// Check if the set was created
|
// Check if the set was created
|
||||||
_, exists := r.ipsetCounter.Get(setName)
|
_, exists := r.ipsetCounter.Get(setName)
|
||||||
@@ -376,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindSetNameInRule(t *testing.T) {
|
||||||
|
r := &router{}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
rule []string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic rule with two sets",
|
||||||
|
rule: []string{
|
||||||
|
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
|
||||||
|
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
|
||||||
|
},
|
||||||
|
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No sets",
|
||||||
|
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple sets with different positions",
|
||||||
|
rule: []string{
|
||||||
|
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
|
||||||
|
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
|
||||||
|
},
|
||||||
|
expected: []string{"set1", "set-abc123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Boundary case - sequence appears at end",
|
||||||
|
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
|
||||||
|
expected: []string{"final-set"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Incomplete pattern - missing set name",
|
||||||
|
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := r.findSets(tc.rule)
|
||||||
|
|
||||||
|
if len(result) != len(tc.expected) {
|
||||||
|
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, set := range result {
|
||||||
|
if set != tc.expected[i] {
|
||||||
|
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -43,6 +40,18 @@ const (
|
|||||||
// Action is the action to be taken on a rule
|
// Action is the action to be taken on a rule
|
||||||
type Action int
|
type Action int
|
||||||
|
|
||||||
|
// String returns the string representation of the action
|
||||||
|
func (a Action) String() string {
|
||||||
|
switch a {
|
||||||
|
case ActionAccept:
|
||||||
|
return "accept"
|
||||||
|
case ActionDrop:
|
||||||
|
return "drop"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ActionAccept is the action to accept a packet
|
// ActionAccept is the action to accept a packet
|
||||||
ActionAccept Action = iota
|
ActionAccept Action = iota
|
||||||
@@ -50,6 +59,33 @@ const (
|
|||||||
ActionDrop
|
ActionDrop
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Network is a rule destination, either a set or a prefix
|
||||||
|
type Network struct {
|
||||||
|
Set Set
|
||||||
|
Prefix netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the destination
|
||||||
|
func (d Network) String() string {
|
||||||
|
if d.Prefix.IsValid() {
|
||||||
|
return d.Prefix.String()
|
||||||
|
}
|
||||||
|
if d.IsSet() {
|
||||||
|
return d.Set.HashedName()
|
||||||
|
}
|
||||||
|
return "<invalid network>"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSet returns true if the destination is a set
|
||||||
|
func (d Network) IsSet() bool {
|
||||||
|
return d.Set != Set{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrefix returns true if the destination is a valid prefix
|
||||||
|
func (d Network) IsPrefix() bool {
|
||||||
|
return d.Prefix.IsValid()
|
||||||
|
}
|
||||||
|
|
||||||
// Manager is the high level abstraction of a firewall manager
|
// Manager is the high level abstraction of a firewall manager
|
||||||
//
|
//
|
||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
@@ -80,13 +116,14 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
|
IsStateful() bool
|
||||||
|
|
||||||
AddRouteFiltering(
|
AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination Network,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort, dPort *Port,
|
||||||
dPort *Port,
|
|
||||||
action Action,
|
action Action,
|
||||||
) (Rule, error)
|
) (Rule, error)
|
||||||
|
|
||||||
@@ -119,6 +156,9 @@ type Manager interface {
|
|||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
DeleteDNATRule(Rule) error
|
DeleteDNATRule(Rule) error
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
@@ -153,22 +193,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
|
||||||
func GenerateSetName(sources []netip.Prefix) string {
|
|
||||||
// sort for consistent naming
|
|
||||||
SortPrefixes(sources)
|
|
||||||
|
|
||||||
var sourcesStr strings.Builder
|
|
||||||
for _, src := range sources {
|
|
||||||
sourcesStr.WriteString(src.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
|
||||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
|
||||||
|
|
||||||
return fmt.Sprintf("nb-%s", shortHash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||||
if len(prefixes) == 0 {
|
if len(prefixes) == 0 {
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result1 := manager.GenerateSetName(prefixes1)
|
result1 := manager.NewPrefixSet(prefixes1)
|
||||||
result2 := manager.GenerateSetName(prefixes2)
|
result2 := manager.NewPrefixSet(prefixes2)
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||||
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("10.0.0.0/8"),
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result := manager.GenerateSetName(prefixes)
|
result := manager.NewPrefixSet(prefixes)
|
||||||
|
|
||||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error matching regex: %v", err)
|
t.Fatalf("Error matching regex: %v", err)
|
||||||
}
|
}
|
||||||
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
result1 := manager.NewPrefixSet([]netip.Prefix{})
|
||||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
result2 := manager.NewPrefixSet([]netip.Prefix{})
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||||
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result1 := manager.GenerateSetName(prefixes1)
|
result1 := manager.NewPrefixSet(prefixes1)
|
||||||
result2 := manager.GenerateSetName(prefixes2)
|
result2 := manager.NewPrefixSet(prefixes2)
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouterPair struct {
|
type RouterPair struct {
|
||||||
ID route.ID
|
ID route.ID
|
||||||
Source netip.Prefix
|
Source Network
|
||||||
Destination netip.Prefix
|
Destination Network
|
||||||
Masquerade bool
|
Masquerade bool
|
||||||
Inverse bool
|
Inverse bool
|
||||||
}
|
}
|
||||||
|
|||||||
74
client/firewall/manager/set.go
Normal file
74
client/firewall/manager/set.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Set struct {
|
||||||
|
hash [4]byte
|
||||||
|
comment string
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the set: hashed name and comment
|
||||||
|
func (h Set) String() string {
|
||||||
|
if h.comment == "" {
|
||||||
|
return h.HashedName()
|
||||||
|
}
|
||||||
|
return h.HashedName() + ": " + h.comment
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashedName returns the string representation of the hash
|
||||||
|
func (h Set) HashedName() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"nb-%s",
|
||||||
|
hex.EncodeToString(h.hash[:]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comment returns the comment of the set
|
||||||
|
func (h Set) Comment() string {
|
||||||
|
return h.comment
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||||
|
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||||
|
// sort for consistent naming
|
||||||
|
SortPrefixes(prefixes)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
for _, src := range prefixes {
|
||||||
|
bytes, err := src.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to marshal prefix %s: %v", src, err)
|
||||||
|
}
|
||||||
|
hash.Write(bytes)
|
||||||
|
}
|
||||||
|
var set Set
|
||||||
|
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||||
|
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDomainSet generates a unique name for an ipset based on the given domains.
|
||||||
|
func NewDomainSet(domains domain.List) Set {
|
||||||
|
slices.Sort(domains)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
for _, d := range domains {
|
||||||
|
hash.Write([]byte(d.PunycodeString()))
|
||||||
|
}
|
||||||
|
set := Set{
|
||||||
|
comment: domains.SafeString(),
|
||||||
|
}
|
||||||
|
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||||
|
|
||||||
|
return set
|
||||||
|
}
|
||||||
@@ -27,7 +27,8 @@ const (
|
|||||||
// filter chains contains the rules that jump to the rules chains
|
// filter chains contains the rules that jump to the rules chains
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
chainNamePrerouting = "netbird-rt-prerouting"
|
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||||
|
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||||
|
|
||||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||||
)
|
)
|
||||||
@@ -462,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
// netbird peer IP.
|
// netbird peer IP.
|
||||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
// Chain is created by route manager
|
||||||
Name: chainNamePrerouting,
|
// TODO: move creation to a common place
|
||||||
|
m.chainPrerouting = &nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
})
|
}
|
||||||
|
|
||||||
m.addFwmarkToForward(chainFwFilter)
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
|
|||||||
@@ -135,17 +135,16 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if !destination.Addr().Is4() {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
@@ -171,6 +170,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -242,7 +245,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close closes the firewall manager
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -325,10 +328,16 @@ func (m *Manager) SetLogLevel(log.Level) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) EnableRouting() error {
|
func (m *Manager) EnableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("enable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
|
if err := m.router.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
return fmt.Errorf("disable IP forwarding: %w", err)
|
||||||
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -359,6 +368,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -3,7 +3,6 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
@@ -25,11 +24,8 @@ var ifaceMock = &iFaceMock{
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -70,11 +66,11 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1").Unmap()
|
||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
rule, err := manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -109,8 +105,6 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
|
||||||
add := ipToAdd.Unmap()
|
|
||||||
expectedExprs2 := []expr.Any{
|
expectedExprs2 := []expr.Any{
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -132,7 +126,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
&expr.Cmp{
|
&expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: add.AsSlice(),
|
Data: ip.AsSlice(),
|
||||||
},
|
},
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
DestRegister: 1,
|
DestRegister: 1,
|
||||||
@@ -173,11 +167,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
},
|
},
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: netip.MustParseAddr("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.96.0.0/16"),
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
|
||||||
Mask: net.IPv4Mask(255, 255, 255, 0),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -197,11 +188,11 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
ip := net.ParseIP("10.20.0.100")
|
ip := netip.MustParseAddr("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -282,14 +273,14 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
verifyIptablesOutput(t, stdout, stderr)
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := netip.MustParseAddr("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
_, err = manager.AddPeerFiltering(nil, ip.AsSlice(), fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
&fw.Port{Values: []uint16{443}},
|
&fw.Port{Values: []uint16{443}},
|
||||||
@@ -298,8 +289,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
require.NoError(t, err, "failed to add route filtering rule")
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
pair := fw.RouterPair{
|
pair := fw.RouterPair{
|
||||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
err = manager.AddNatRule(pair)
|
err = manager.AddNatRule(pair)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
@@ -44,9 +43,14 @@ const (
|
|||||||
const refreshRulesMapError = "refresh rules map: %w"
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type setInput struct {
|
||||||
|
set firewall.Set
|
||||||
|
prefixes []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
@@ -54,7 +58,7 @@ type router struct {
|
|||||||
chains map[string]*nftables.Chain
|
chains map[string]*nftables.Chain
|
||||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
@@ -100,6 +104,10 @@ func (r *router) init(workTable *nftables.Table) error {
|
|||||||
return fmt.Errorf("create containers: %w", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
@@ -196,15 +204,21 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Chain is created by acl manager
|
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
// TODO: move creation to a common place
|
Name: chainNameManglePostrouting,
|
||||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
Table: r.workTable,
|
||||||
Name: chainNamePrerouting,
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
}
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
// Add the single NAT rule that matches on mark
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
@@ -220,7 +234,83 @@ func (r *router) createContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDataPlaneMark configures the fwmark for the data plane
|
||||||
|
func (r *router) setupDataPlaneMark() error {
|
||||||
|
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||||
|
return errors.New("no mangle chains found")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctNew := getCtNewExprs()
|
||||||
|
preExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
preExprs = append(preExprs, ctNew...)
|
||||||
|
preExprs = append(preExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
preNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
|
Exprs: preExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(preNftRule)
|
||||||
|
|
||||||
|
postExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postExprs = append(postExprs, ctNew...)
|
||||||
|
postExprs = append(postExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
postNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePostrouting],
|
||||||
|
Exprs: postExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(postNftRule)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -230,7 +320,7 @@ func (r *router) createContainers() error {
|
|||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
@@ -245,23 +335,29 @@ func (r *router) AddRouteFiltering(
|
|||||||
chain := r.chains[chainNameRoutingFw]
|
chain := r.chains[chainNameRoutingFw]
|
||||||
var exprs []expr.Any
|
var exprs []expr.Any
|
||||||
|
|
||||||
|
var source firewall.Network
|
||||||
switch {
|
switch {
|
||||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||||
case len(sources) == 1:
|
case len(sources) == 1:
|
||||||
// If there's only one source, we can use it directly
|
// If there's only one source, we can use it directly
|
||||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
source.Prefix = sources[0]
|
||||||
default:
|
default:
|
||||||
// If there are multiple sources, create or get an ipset
|
// If there are multiple sources, use a set
|
||||||
var err error
|
source.Set = firewall.NewPrefixSet(sources)
|
||||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle destination
|
sourceExp, err := r.applyNetwork(source, sources, true)
|
||||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, sourceExp...)
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
// Handle protocol
|
// Handle protocol
|
||||||
if proto != firewall.ProtocolALL {
|
if proto != firewall.ProtocolALL {
|
||||||
@@ -305,39 +401,27 @@ func (r *router) AddRouteFiltering(
|
|||||||
rule = r.conn.AddRule(rule)
|
rule = r.conn.AddRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return nil, fmt.Errorf(flushError, err)
|
return nil, fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[string(ruleKey)] = rule
|
r.rules[string(ruleKey)] = rule
|
||||||
|
|
||||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||||
setName := firewall.GenerateSetName(sources)
|
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
set: set,
|
||||||
|
prefixes: prefixes,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs,
|
return getIpSetExprs(ref, isSource)
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ref.Out.Name,
|
|
||||||
SetID: ref.Out.ID,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return exprs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
@@ -356,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
setName := r.findSetNameInRule(nftRule)
|
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
return fmt.Errorf("delete: %w", err)
|
return fmt.Errorf("delete: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if setName != "" {
|
|
||||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
|
||||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||||
// overlapping prefixes will result in an error, so we need to merge them
|
// overlapping prefixes will result in an error, so we need to merge them
|
||||||
sources = firewall.MergeIPRanges(sources)
|
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||||
|
|
||||||
set := &nftables.Set{
|
nfset := &nftables.Set{
|
||||||
Name: setName,
|
Name: setName,
|
||||||
|
Comment: input.set.Comment(),
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
// required for prefixes
|
// required for prefixes
|
||||||
Interval: true,
|
Interval: true,
|
||||||
KeyType: nftables.TypeIPAddr,
|
KeyType: nftables.TypeIPAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
elements := convertPrefixesToSet(prefixes)
|
||||||
|
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||||
|
|
||||||
|
return nfset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||||
var elements []nftables.SetElement
|
var elements []nftables.SetElement
|
||||||
for _, prefix := range sources {
|
for _, prefix := range prefixes {
|
||||||
// TODO: Implement IPv6 support
|
// TODO: Implement IPv6 support
|
||||||
if prefix.Addr().Is6() {
|
if prefix.Addr().Is6() {
|
||||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -407,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
|
|||||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
return elements
|
||||||
if err := r.conn.AddSet(set, elements); err != nil {
|
|
||||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
|
||||||
|
|
||||||
return set, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateLastIP determines the last IP in a given prefix.
|
// calculateLastIP determines the last IP in a given prefix.
|
||||||
@@ -442,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||||
r.conn.DelSet(set)
|
r.conn.DelSet(nfset)
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
@@ -452,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
|
||||||
|
sets := r.findSets(rule)
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, setName := range sets {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSets(rule *nftables.Rule) []string {
|
||||||
|
var sets []string
|
||||||
for _, e := range rule.Exprs {
|
for _, e := range rule.Exprs {
|
||||||
if lookup, ok := e.(*expr.Lookup); ok {
|
if lookup, ok := e.(*expr.Lookup); ok {
|
||||||
return lookup.SetName
|
sets = append(sets, lookup.SetName)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
return sets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||||
@@ -474,10 +573,6 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
|||||||
|
|
||||||
// AddNatRule appends a nftables rule pair to the nat chain
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -500,7 +595,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
// TODO: rollback ipset counter
|
||||||
|
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -508,8 +604,15 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
op := expr.CmpOpEq
|
op := expr.CmpOpEq
|
||||||
if pair.Inverse {
|
if pair.Inverse {
|
||||||
@@ -517,26 +620,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
|
||||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeySTATE,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
|
|
||||||
// interface matching
|
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: expr.MetaKeyIIFNAME,
|
Key: expr.MetaKeyIIFNAME,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
@@ -547,6 +630,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
Data: ifname(r.wgIface.Name()),
|
Data: ifname(r.wgIface.Name()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||||
|
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||||
|
exprs = append(exprs, getCtNewExprs()...)
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
exprs = append(exprs, destExp...)
|
exprs = append(exprs, destExp...)
|
||||||
@@ -576,9 +662,11 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNamePrerouting],
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
Exprs: exprs,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
@@ -659,8 +747,15 @@ func (r *router) addPostroutingRules() error {
|
|||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
&expr.Counter{},
|
&expr.Counter{},
|
||||||
@@ -669,7 +764,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
exprs = append(exprs, sourceExp...)
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
@@ -682,7 +778,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
Exprs: expression,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@@ -697,11 +793,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
|
||||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -904,14 +1002,11 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
|
|
||||||
// RemoveNatRule removes the prerouting mark rule
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
|
||||||
log.Errorf("%v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if pair.Masquerade {
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -919,16 +1014,17 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove legacy routing rule: %w", err)
|
return fmt.Errorf("remove legacy routing rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
// TODO: rollback set counter
|
||||||
|
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -936,16 +1032,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
err := r.conn.DelRule(rule)
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -957,7 +1056,7 @@ func (r *router) refreshRulesMap() error {
|
|||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
return fmt.Errorf(" unable to list rules: %v", err)
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
@@ -1231,13 +1330,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||||
var offset uint32
|
if err != nil {
|
||||||
if source {
|
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||||
offset = 12 // src offset
|
}
|
||||||
} else {
|
|
||||||
offset = 16 // dst offset
|
elements := convertPrefixesToSet(prefixes)
|
||||||
|
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||||
|
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||||
|
func (r *router) applyNetwork(
|
||||||
|
network firewall.Network,
|
||||||
|
setPrefixes []netip.Prefix,
|
||||||
|
isSource bool,
|
||||||
|
) ([]expr.Any, error) {
|
||||||
|
if network.IsSet() {
|
||||||
|
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("source: %w", err)
|
||||||
|
}
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return applyPrefix(network.Prefix, isSource), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||||
|
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||||
|
// dst offset
|
||||||
|
offset := uint32(16)
|
||||||
|
if isSource {
|
||||||
|
// src offset
|
||||||
|
offset = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
ones := prefix.Bits()
|
ones := prefix.Bits()
|
||||||
@@ -1324,3 +1464,48 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
|
|
||||||
return exprs
|
return exprs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getCtNewExprs() []expr.Any {
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||||
|
|
||||||
|
// dst offset
|
||||||
|
offset := uint32(16)
|
||||||
|
if isSource {
|
||||||
|
// src offset
|
||||||
|
offset = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ref.Out.Name,
|
||||||
|
SetID: ref.Out.ID,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build CIDR matching expressions
|
// Build CIDR matching expressions
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||||
|
|
||||||
// Combine all expressions in the correct order
|
// Combine all expressions in the correct order
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range rtr.chains {
|
for _, chain := range rtr.chains {
|
||||||
if chain.Name == chainNamePrerouting {
|
if chain.Name == chainNameManglePrerouting {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
// Verify the rule was added
|
// Verify the rule was added
|
||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := false
|
found := false
|
||||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules")
|
require.NoError(t, err, "should list rules")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
|
|
||||||
// Verify the rule was removed
|
// Verify the rule was removed
|
||||||
found = false
|
found = false
|
||||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules after removal")
|
require.NoError(t, err, "should list rules after removal")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
setName := firewall.GenerateSetName(tt.sources)
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
set, err := r.createIpSet(setName, tt.sources)
|
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Failed to create IP set: %v", err)
|
t.Logf("Failed to create IP set: %v", err)
|
||||||
printNftSets()
|
printNftSets()
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ var (
|
|||||||
Name: "Insert Forwarding IPV4 Rule",
|
Name: "Insert Forwarding IPV4 Rule",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -24,8 +24,8 @@ var (
|
|||||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -40,8 +40,8 @@ var (
|
|||||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -12,7 +12,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -22,7 +21,7 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -32,17 +31,14 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if fwder := m.forwarder.Load(); fwder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
|||||||
@@ -189,7 +189,7 @@ func (t *ICMPTracker) cleanup() {
|
|||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -23,11 +23,11 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPSyn uint8 = 0x02
|
|
||||||
TCPAck uint8 = 0x10
|
|
||||||
TCPFin uint8 = 0x01
|
TCPFin uint8 = 0x01
|
||||||
|
TCPSyn uint8 = 0x02
|
||||||
TCPRst uint8 = 0x04
|
TCPRst uint8 = 0x04
|
||||||
TCPPush uint8 = 0x08
|
TCPPush uint8 = 0x08
|
||||||
|
TCPAck uint8 = 0x10
|
||||||
TCPUrg uint8 = 0x20
|
TCPUrg uint8 = 0x20
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -41,7 +41,7 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// TCPState represents the state of a TCP connection
|
// TCPState represents the state of a TCP connection
|
||||||
type TCPState int
|
type TCPState int32
|
||||||
|
|
||||||
func (s TCPState) String() string {
|
func (s TCPState) String() string {
|
||||||
switch s {
|
switch s {
|
||||||
@@ -91,20 +91,23 @@ type TCPConnTrack struct {
|
|||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestPort uint16
|
DestPort uint16
|
||||||
State TCPState
|
state atomic.Int32
|
||||||
established atomic.Bool
|
|
||||||
tombstone atomic.Bool
|
tombstone atomic.Bool
|
||||||
sync.RWMutex
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEstablished safely checks if connection is established
|
// GetState safely retrieves the current state
|
||||||
func (t *TCPConnTrack) IsEstablished() bool {
|
func (t *TCPConnTrack) GetState() TCPState {
|
||||||
return t.established.Load()
|
return TCPState(t.state.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetEstablished safely sets the established state
|
// SetState safely updates the current state
|
||||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||||
t.established.Store(state)
|
t.state.Store(int32(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||||
|
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||||
|
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsTombstone safely checks if the connection is marked for deletion
|
// IsTombstone safely checks if the connection is marked for deletion
|
||||||
@@ -125,13 +128,17 @@ type TCPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
|
waitTimeout time.Duration
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||||
|
waitTimeout := TimeWaitTimeout
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultTCPTimeout
|
timeout = DefaultTCPTimeout
|
||||||
|
} else {
|
||||||
|
waitTimeout = timeout / 45
|
||||||
}
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
@@ -142,6 +149,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
|||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
|
waitTimeout: waitTimeout,
|
||||||
flowLogger: flowLogger,
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -149,7 +157,7 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftyp
|
|||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
key := ConnKey{
|
key := ConnKey{
|
||||||
SrcIP: srcIP,
|
SrcIP: srcIP,
|
||||||
DstIP: dstIP,
|
DstIP: dstIP,
|
||||||
@@ -162,12 +170,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
|||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if exists {
|
if exists {
|
||||||
conn.Lock()
|
t.updateState(key, conn, flags, direction, size)
|
||||||
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
|
||||||
conn.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateCounters(direction, size)
|
|
||||||
|
|
||||||
return key, true
|
return key, true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -175,7 +178,7 @@ func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound TCP connection
|
// TrackOutbound records an outbound TCP connection
|
||||||
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||||
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||||
// if (inverted direction) conn is not tracked, track this direction
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||||
@@ -183,14 +186,14 @@ func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort u
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackInbound processes an inbound TCP packet and updates connection state
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||||
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track is the common implementation for tracking both inbound and outbound connections
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||||
if exists {
|
if exists || flags&TCPSyn == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -205,12 +208,11 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
|||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.established.Store(false)
|
|
||||||
conn.tombstone.Store(false)
|
conn.tombstone.Store(false)
|
||||||
|
conn.state.Store(int32(TCPStateNew))
|
||||||
|
|
||||||
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||||
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
t.updateState(key, conn, flags, direction, size)
|
||||||
conn.UpdateCounters(direction, size)
|
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
t.connections[key] = conn
|
t.connections[key] = conn
|
||||||
@@ -220,7 +222,7 @@ func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, d
|
|||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||||
key := ConnKey{
|
key := ConnKey{
|
||||||
SrcIP: dstIP,
|
SrcIP: dstIP,
|
||||||
DstIP: srcIP,
|
DstIP: srcIP,
|
||||||
@@ -232,134 +234,125 @@ func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort
|
|||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if !exists || conn.IsTombstone() {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle RST flag specially - it always causes transition to closed
|
currentState := conn.GetState()
|
||||||
if flags&TCPRst != 0 {
|
if !t.isValidStateForFlags(currentState, flags) {
|
||||||
return t.handleRst(key, conn, size)
|
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||||
}
|
// allow all flags for established for now
|
||||||
|
if currentState == TCPStateEstablished {
|
||||||
conn.Lock()
|
|
||||||
t.updateState(key, conn, flags, false)
|
|
||||||
isEstablished := conn.IsEstablished()
|
|
||||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
|
||||||
conn.Unlock()
|
|
||||||
conn.UpdateCounters(nftypes.Ingress, size)
|
|
||||||
|
|
||||||
return isEstablished || isValidState
|
|
||||||
}
|
|
||||||
|
|
||||||
func (t *TCPTracker) handleRst(key ConnKey, conn *TCPConnTrack, size int) bool {
|
|
||||||
if conn.IsTombstone() {
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
conn.Lock()
|
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||||
conn.SetTombstone()
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
conn.Unlock()
|
|
||||||
conn.UpdateCounters(nftypes.Ingress, size)
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection reset: %s", key)
|
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates the TCP connection state based on flags
|
// updateState updates the TCP connection state based on flags
|
||||||
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||||
conn.UpdateLastSeen()
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(packetDir, size)
|
||||||
|
|
||||||
state := conn.State
|
currentState := conn.GetState()
|
||||||
defer func() {
|
|
||||||
if state != conn.State {
|
if flags&TCPRst != 0 {
|
||||||
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
}
|
}
|
||||||
}()
|
|
||||||
|
|
||||||
switch state {
|
var newState TCPState
|
||||||
|
switch currentState {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
conn.State = TCPStateSynSent
|
if conn.Direction == nftypes.Egress {
|
||||||
|
newState = TCPStateSynSent
|
||||||
|
} else {
|
||||||
|
newState = TCPStateSynReceived
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
if isOutbound {
|
if packetDir != conn.Direction {
|
||||||
conn.State = TCPStateEstablished
|
newState = TCPStateEstablished
|
||||||
conn.SetEstablished(true)
|
|
||||||
} else {
|
} else {
|
||||||
// Simultaneous open
|
// Simultaneous open
|
||||||
conn.State = TCPStateSynReceived
|
newState = TCPStateSynReceived
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateSynReceived:
|
case TCPStateSynReceived:
|
||||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||||
conn.State = TCPStateEstablished
|
if packetDir == conn.Direction {
|
||||||
conn.SetEstablished(true)
|
newState = TCPStateEstablished
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateEstablished:
|
case TCPStateEstablished:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
if isOutbound {
|
if packetDir == conn.Direction {
|
||||||
conn.State = TCPStateFinWait1
|
newState = TCPStateFinWait1
|
||||||
} else {
|
} else {
|
||||||
conn.State = TCPStateCloseWait
|
newState = TCPStateCloseWait
|
||||||
}
|
}
|
||||||
conn.SetEstablished(false)
|
|
||||||
} else if flags&TCPRst != 0 {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetTombstone()
|
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
|
if packetDir != conn.Direction {
|
||||||
switch {
|
switch {
|
||||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
conn.State = TCPStateClosing
|
newState = TCPStateClosing
|
||||||
case flags&TCPFin != 0:
|
case flags&TCPFin != 0:
|
||||||
conn.State = TCPStateFinWait2
|
newState = TCPStateClosing
|
||||||
case flags&TCPAck != 0:
|
case flags&TCPAck != 0:
|
||||||
conn.State = TCPStateFinWait2
|
newState = TCPStateFinWait2
|
||||||
case flags&TCPRst != 0:
|
}
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetTombstone()
|
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait2:
|
case TCPStateFinWait2:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
newState = TCPStateTimeWait
|
||||||
|
|
||||||
t.logger.Trace("TCP connection %s completed", key)
|
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateClosing:
|
case TCPStateClosing:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
newState = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateLastAck
|
newState = TCPStateLastAck
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
newState = TCPStateClosed
|
||||||
conn.SetTombstone()
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// Send close event for gracefully closed connections
|
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||||
|
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||||
|
|
||||||
|
switch newState {
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
|
||||||
|
case TCPStateClosed:
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
t.logger.Trace("TCP connection %s closed gracefully", key)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -369,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
if !isValidFlagCombination(flags) {
|
if !isValidFlagCombination(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
if state == TCPStateSynSent {
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
switch state {
|
switch state {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
|
// TODO: support simultaneous open
|
||||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||||
case TCPStateSynReceived:
|
case TCPStateSynReceived:
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateEstablished:
|
case TCPStateEstablished:
|
||||||
if flags&TCPRst != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
@@ -397,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateClosed:
|
case TCPStateClosed:
|
||||||
// Accept retransmitted ACKs in closed state
|
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||||
// This is important because the final ACK might be lost
|
|
||||||
// and the peer will retransmit their FIN-ACK
|
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -430,23 +425,24 @@ func (t *TCPTracker) cleanup() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
switch {
|
currentState := conn.GetState()
|
||||||
case conn.State == TCPStateTimeWait:
|
switch currentState {
|
||||||
timeout = TimeWaitTimeout
|
case TCPStateTimeWait:
|
||||||
case conn.IsEstablished():
|
timeout = t.waitTimeout
|
||||||
|
case TCPStateEstablished:
|
||||||
timeout = t.timeout
|
timeout = t.timeout
|
||||||
default:
|
default:
|
||||||
timeout = TCPHandshakeTimeout
|
timeout = TCPHandshakeTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(timeout) {
|
if conn.timeoutExceeded(timeout) {
|
||||||
// Return IPs to pool
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||||
|
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
|
||||||
// event already handled by state change
|
// event already handled by state change
|
||||||
if conn.State != TCPStateTimeWait {
|
if currentState != TCPStateTimeWait {
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
if i%2 == 0 {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
} else {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark connection cleanup
|
||||||
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Pre-populate with expired connections
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for connections to expire
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.cleanup()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -5,6 +5,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -124,9 +125,6 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
// Receive RST
|
// Receive RST
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
require.True(t, valid, "RST should be allowed for established connection")
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
|
||||||
// The connection will be cleaned up by timeout
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -217,97 +215,446 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
conn := tracker.connections[key]
|
conn := tracker.connections[key]
|
||||||
if tt.wantValid {
|
if tt.wantValid {
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
require.Equal(t, TCPStateClosed, conn.State)
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
require.False(t, conn.IsEstablished())
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTCPRetransmissions(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test SYN retransmission
|
||||||
|
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||||
|
// Initial SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Retransmit SYN (should not affect the state machine)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Verify we're still in SYN-SENT state
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||||
|
|
||||||
|
// Complete the handshake
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// Verify we're in ESTABLISHED state
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test ACK retransmission in established state
|
||||||
|
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Retransmit ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// State should remain ESTABLISHED
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test FIN retransmission
|
||||||
|
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Retransmit FIN (should not change state)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPDataTransfer(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Data Transfer", func(t *testing.T) {
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||||
|
|
||||||
|
// Receive ACK for data
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Send ACK for received data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
|
||||||
|
// State should remain ESTABLISHED
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||||
|
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||||
|
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||||
|
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test half-closed connection: local end closes, remote end continues sending data
|
||||||
|
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Remote end can still send data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// We can still ACK their data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// Receive FIN from remote end
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test half-closed connection: remote end closes, local end continues sending data
|
||||||
|
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Receive FIN from remote
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
|
||||||
|
// We can still send data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||||
|
|
||||||
|
// Remote can still ACK our data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Send our FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
|
||||||
|
// Receive final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPAbnormalSequences(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test handling of unsolicited RST in various states
|
||||||
|
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||||
|
// Send SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Receive unsolicited RST (without proper ACK)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
|
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||||
|
|
||||||
|
// Receive RST with proper ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
|
// Create tracker with a very short timeout for testing
|
||||||
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Connection Timeout", func(t *testing.T) {
|
||||||
|
// Establish a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Wait for the connection to timeout
|
||||||
|
time.Sleep(2 * shortTimeout)
|
||||||
|
|
||||||
|
// Force cleanup
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Connection should be removed
|
||||||
|
_, exists := tracker.connections[key]
|
||||||
|
require.False(t, exists, "Connection should be removed after timeout")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Complete the connection close to enter TIME_WAIT
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||||
|
// For the test, we're using a short timeout
|
||||||
|
time.Sleep(2 * shortTimeout)
|
||||||
|
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Connection should be removed
|
||||||
|
_, exists := tracker.connections[key]
|
||||||
|
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSynFlood(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
basePort := uint16(10000)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Create a large number of SYN packets to simulate a SYN flood
|
||||||
|
for i := uint16(0); i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we're tracking all connections
|
||||||
|
require.Equal(t, 1000, len(tracker.connections))
|
||||||
|
|
||||||
|
// Now simulate SYN timeout
|
||||||
|
var oldConns int
|
||||||
|
tracker.mutex.Lock()
|
||||||
|
for _, conn := range tracker.connections {
|
||||||
|
if conn.GetState() == TCPStateSynSent {
|
||||||
|
// Make the connection appear old
|
||||||
|
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||||
|
oldConns++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracker.mutex.Unlock()
|
||||||
|
require.Equal(t, 1000, oldConns)
|
||||||
|
|
||||||
|
// Run cleanup
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Check that stale connections were cleaned up
|
||||||
|
require.Equal(t, 0, len(tracker.connections))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
clientPort := uint16(12345)
|
||||||
|
serverPort := uint16(80)
|
||||||
|
|
||||||
|
// 1. Client sends SYN (we receive it as inbound)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: clientIP,
|
||||||
|
DstIP: serverIP,
|
||||||
|
SrcPort: clientPort,
|
||||||
|
DstPort: serverPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.mutex.RLock()
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
tracker.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||||
|
|
||||||
|
// 2. Server sends SYN-ACK response
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
|
||||||
|
// 3. Client sends ACK to complete handshake
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||||
|
|
||||||
|
// 4. Test data transfer
|
||||||
|
// Client sends data
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||||
|
|
||||||
|
// Server sends ACK for data
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
// Server sends data
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||||
|
|
||||||
|
// Client sends ACK for data
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||||
|
|
||||||
|
// Verify state and counters
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||||
|
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||||
|
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||||
|
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||||
|
}
|
||||||
|
|
||||||
// Helper to establish a TCP connection
|
// Helper to establish a TCP connection
|
||||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
|
||||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
|
||||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
|
||||||
|
|
||||||
// Pre-populate some connections
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
|
||||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
|
||||||
i := 0
|
|
||||||
for pb.Next() {
|
|
||||||
if i%2 == 0 {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
|
||||||
} else {
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark connection cleanup
|
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
|
||||||
srcIP := netip.MustParseAddr("192.168.1.1")
|
|
||||||
dstIP := netip.MustParseAddr("192.168.1.2")
|
|
||||||
for i := 0; i < 10000; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections to expire
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.cleanup()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -165,7 +165,7 @@ func (t *UDPTracker) cleanup() {
|
|||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,7 +4,9 @@ import (
|
|||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"gvisor.dev/gvisor/pkg/buffer"
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
@@ -17,6 +19,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
@@ -31,12 +34,14 @@ const (
|
|||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
// ruleIdMap is used to store the rule ID for a given connection
|
||||||
|
ruleIdMap sync.Map
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ip net.IP
|
ip tcpip.Address
|
||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -66,12 +71,11 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ones, _ := iface.Address().Network.Mask.Size()
|
|
||||||
protoAddr := tcpip.ProtocolAddress{
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
Protocol: ipv4.ProtocolNumber,
|
Protocol: ipv4.ProtocolNumber,
|
||||||
AddressWithPrefix: tcpip.AddressWithPrefix{
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
Address: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
PrefixLen: ones,
|
PrefixLen: iface.Address().Network.Bits(),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -111,7 +115,7 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.Flow
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
ip: iface.Address().IP,
|
ip: tcpip.AddrFromSlice(iface.Address().IP.AsSlice()),
|
||||||
}
|
}
|
||||||
|
|
||||||
receiveWindow := defaultReceiveWindow
|
receiveWindow := defaultReceiveWindow
|
||||||
@@ -162,8 +166,39 @@ func (f *Forwarder) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
if f.netstack && f.ip.Equal(addr) {
|
||||||
return net.IPv4(127, 0, 0, 1)
|
return net.IPv4(127, 0, 0, 1)
|
||||||
}
|
}
|
||||||
return addr.AsSlice()
|
return addr.AsSlice()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) RegisterRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, ruleID []byte) {
|
||||||
|
key := buildKey(srcIP, dstIP, srcPort, dstPort)
|
||||||
|
f.ruleIdMap.LoadOrStore(key, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) getRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
|
if value, ok := f.ruleIdMap.Load(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||||
|
return value.([]byte), true
|
||||||
|
} else if value, ok := f.ruleIdMap.Load(buildKey(dstIP, srcIP, dstPort, srcPort)); ok {
|
||||||
|
return value.([]byte), true
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) DeleteRuleID(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
|
if _, ok := f.ruleIdMap.LoadAndDelete(buildKey(srcIP, dstIP, srcPort, dstPort)); ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.ruleIdMap.LoadAndDelete(buildKey(dstIP, srcIP, dstPort, srcPort))
|
||||||
|
}
|
||||||
|
|
||||||
|
func buildKey(srcIP, dstIP netip.Addr, srcPort, dstPort uint16) conntrack.ConnKey {
|
||||||
|
return conntrack.ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -25,7 +25,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
}
|
}
|
||||||
|
|
||||||
flowID := uuid.New()
|
flowID := uuid.New()
|
||||||
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode, 0, 0)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
@@ -34,14 +34,14 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
// TODO: support non-root
|
// TODO: support non-root
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
f.logger.Error("forwarder: Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||||
|
|
||||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := conn.Close(); err != nil {
|
if err := conn.Close(); err != nil {
|
||||||
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
f.logger.Debug("forwarder: Failed to close ICMP socket: %v", err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -52,36 +52,37 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
payload := fullPacket.AsSlice()
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
if _, err = conn.WriteTo(payload, dst); err != nil {
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
f.logger.Error("forwarder: Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
f.logger.Trace("forwarder: Forwarded ICMP packet %v type %v code %v",
|
||||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
// For Echo Requests, send and handle response
|
// For Echo Requests, send and handle response
|
||||||
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||||
f.handleEchoResponse(icmpHdr, conn, id)
|
rxBytes := pkt.Size()
|
||||||
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
txBytes := f.handleEchoResponse(icmpHdr, conn, id)
|
||||||
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode, uint64(rxBytes), uint64(txBytes))
|
||||||
}
|
}
|
||||||
|
|
||||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) int {
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error("forwarder: Failed to set read deadline for ICMP response: %v", err)
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, f.endpoint.mtu)
|
response := make([]byte, f.endpoint.mtu)
|
||||||
n, _, err := conn.ReadFrom(response)
|
n, _, err := conn.ReadFrom(response)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if !isTimeout(err) {
|
if !isTimeout(err) {
|
||||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
f.logger.Error("forwarder: Failed to read ICMP response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
@@ -100,28 +101,54 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketCon
|
|||||||
fullPacket = append(fullPacket, response[:n]...)
|
fullPacket = append(fullPacket, response[:n]...)
|
||||||
|
|
||||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
f.logger.Error("forwarder: Failed to inject ICMP response: %v", err)
|
||||||
|
|
||||||
return
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
f.logger.Trace("forwarder: Forwarded ICMP echo reply for %v type %v code %v",
|
||||||
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
return len(fullPacket)
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendICMPEvent stores flow events for ICMP packets
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8, rxBytes, txBytes uint64) {
|
||||||
f.flowLogger.StoreEvent(nftypes.EventFields{
|
var rxPackets, txPackets uint64
|
||||||
|
if rxBytes > 0 {
|
||||||
|
rxPackets = 1
|
||||||
|
}
|
||||||
|
if txBytes > 0 {
|
||||||
|
txPackets = 1
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.ICMP,
|
Protocol: nftypes.ICMP,
|
||||||
// TODO: handle ipv6
|
// TODO: handle ipv6
|
||||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
SourceIP: srcIp,
|
||||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
DestIP: dstIp,
|
||||||
ICMPType: icmpType,
|
ICMPType: icmpType,
|
||||||
ICMPCode: icmpCode,
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
// TODO: get packets/bytes
|
RxBytes: rxBytes,
|
||||||
})
|
TxBytes: txBytes,
|
||||||
|
RxPackets: rxPackets,
|
||||||
|
TxPackets: txPackets,
|
||||||
|
}
|
||||||
|
|
||||||
|
if typ == nftypes.TypeStart {
|
||||||
|
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||||
|
fields.RuleID = ruleId
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -6,8 +6,10 @@ import (
|
|||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
|
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
@@ -23,11 +25,11 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
flowID := uuid.New()
|
flowID := uuid.New()
|
||||||
|
|
||||||
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||||
var success bool
|
var success bool
|
||||||
defer func() {
|
defer func() {
|
||||||
if !success {
|
if !success {
|
||||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -65,67 +67,97 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
defer func() {
|
|
||||||
if err := inConn.Close(); err != nil {
|
|
||||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
|
||||||
}
|
|
||||||
if err := outConn.Close(); err != nil {
|
|
||||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
|
||||||
}
|
|
||||||
ep.Close()
|
|
||||||
|
|
||||||
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
|
||||||
}()
|
|
||||||
|
|
||||||
// Create context for managing the proxy goroutines
|
|
||||||
ctx, cancel := context.WithCancel(f.ctx)
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
// Close connections and endpoint.
|
||||||
|
if err := inConn.Close(); err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
var (
|
||||||
|
bytesFromInToOut int64 // bytes from client to server (tx for client)
|
||||||
|
bytesFromOutToIn int64 // bytes from server to client (rx for client)
|
||||||
|
errInToOut error
|
||||||
|
errOutToIn error
|
||||||
|
)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(outConn, inConn)
|
bytesFromInToOut, errInToOut = io.Copy(outConn, inConn)
|
||||||
errChan <- err
|
cancel()
|
||||||
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
_, err := io.Copy(inConn, outConn)
|
|
||||||
errChan <- err
|
bytesFromOutToIn, errOutToIn = io.Copy(inConn, outConn)
|
||||||
|
cancel()
|
||||||
|
wg.Done()
|
||||||
}()
|
}()
|
||||||
|
|
||||||
select {
|
wg.Wait()
|
||||||
case <-ctx.Done():
|
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
if errInToOut != nil {
|
||||||
return
|
if !isClosedError(errInToOut) {
|
||||||
case err := <-errChan:
|
f.logger.Error("proxyTCP: copy error (in -> out) for %s: %v", epID(id), errInToOut)
|
||||||
if err != nil && !isClosedError(err) {
|
|
||||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
}
|
||||||
return
|
if errOutToIn != nil {
|
||||||
|
if !isClosedError(errOutToIn) {
|
||||||
|
f.logger.Error("proxyTCP: copy error (out -> in) for %s: %v", epID(id), errOutToIn)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
var rxPackets, txPackets uint64
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
rxPackets = tcpStats.SegmentsSent.Value()
|
||||||
|
txPackets = tcpStats.SegmentsReceived.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: Removed TCP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, bytesFromOutToIn, txPackets, bytesFromInToOut)
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, uint64(bytesFromOutToIn), uint64(bytesFromInToOut), rxPackets, txPackets)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.TCP,
|
Protocol: nftypes.TCP,
|
||||||
// TODO: handle ipv6
|
// TODO: handle ipv6
|
||||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
SourceIP: srcIp,
|
||||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
DestPort: id.LocalPort,
|
DestPort: id.LocalPort,
|
||||||
|
RxBytes: rxBytes,
|
||||||
|
TxBytes: txBytes,
|
||||||
|
RxPackets: rxPackets,
|
||||||
|
TxPackets: txPackets,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ep != nil {
|
if typ == nftypes.TypeStart {
|
||||||
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||||
// fields are flipped since this is the in conn
|
fields.RuleID = ruleId
|
||||||
// TODO: get bytes
|
|
||||||
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
|
||||||
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.flowLogger.StoreEvent(fields)
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
|||||||
@@ -149,11 +149,11 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
|
|
||||||
flowID := uuid.New()
|
flowID := uuid.New()
|
||||||
|
|
||||||
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id, 0, 0, 0, 0)
|
||||||
var success bool
|
var success bool
|
||||||
defer func() {
|
defer func() {
|
||||||
if !success {
|
if !success {
|
||||||
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, 0, 0, 0, 0)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
@@ -199,7 +199,6 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[id] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
@@ -212,68 +211,94 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
defer func() {
|
|
||||||
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
<-ctx.Done()
|
||||||
|
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := pConn.conn.Close(); err != nil {
|
if err := pConn.conn.Close(); err != nil && !isClosedError(err) {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := pConn.outConn.Close(); err != nil {
|
if err := pConn.outConn.Close(); err != nil && !isClosedError(err) {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ep.Close()
|
ep.Close()
|
||||||
|
}()
|
||||||
|
|
||||||
|
var wg sync.WaitGroup
|
||||||
|
wg.Add(2)
|
||||||
|
|
||||||
|
var txBytes, rxBytes int64
|
||||||
|
var outboundErr, inboundErr error
|
||||||
|
|
||||||
|
// outbound->inbound: copy from pConn.conn to pConn.outConn
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
txBytes, outboundErr = pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
// inbound->outbound: copy from pConn.outConn to pConn.conn
|
||||||
|
go func() {
|
||||||
|
defer wg.Done()
|
||||||
|
rxBytes, inboundErr = pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
wg.Wait()
|
||||||
|
|
||||||
|
if outboundErr != nil && !isClosedError(outboundErr) {
|
||||||
|
f.logger.Error("proxyUDP: copy error (outbound->inbound) for %s: %v", epID(id), outboundErr)
|
||||||
|
}
|
||||||
|
if inboundErr != nil && !isClosedError(inboundErr) {
|
||||||
|
f.logger.Error("proxyUDP: copy error (inbound->outbound) for %s: %v", epID(id), inboundErr)
|
||||||
|
}
|
||||||
|
|
||||||
|
var rxPackets, txPackets uint64
|
||||||
|
if udpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
rxPackets = udpStats.PacketsSent.Value()
|
||||||
|
txPackets = udpStats.PacketsReceived.Value()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: Removed UDP connection %s [in: %d Pkts/%d B, out: %d Pkts/%d B]", epID(id), rxPackets, rxBytes, txPackets, txBytes)
|
||||||
|
|
||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
delete(f.udpForwarder.conns, id)
|
delete(f.udpForwarder.conns, id)
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, uint64(rxBytes), uint64(txBytes), rxPackets, txPackets)
|
||||||
}()
|
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
|
||||||
}()
|
|
||||||
|
|
||||||
go func() {
|
|
||||||
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
|
||||||
}()
|
|
||||||
|
|
||||||
select {
|
|
||||||
case <-ctx.Done():
|
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
|
||||||
return
|
|
||||||
case err := <-errChan:
|
|
||||||
if err != nil && !isClosedError(err) {
|
|
||||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
|
||||||
}
|
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// sendUDPEvent stores flow events for UDP connections
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, rxBytes, txBytes, rxPackets, txPackets uint64) {
|
||||||
|
srcIp := netip.AddrFrom4(id.RemoteAddress.As4())
|
||||||
|
dstIp := netip.AddrFrom4(id.LocalAddress.As4())
|
||||||
|
|
||||||
fields := nftypes.EventFields{
|
fields := nftypes.EventFields{
|
||||||
FlowID: flowID,
|
FlowID: flowID,
|
||||||
Type: typ,
|
Type: typ,
|
||||||
Direction: nftypes.Ingress,
|
Direction: nftypes.Ingress,
|
||||||
Protocol: nftypes.UDP,
|
Protocol: nftypes.UDP,
|
||||||
// TODO: handle ipv6
|
// TODO: handle ipv6
|
||||||
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
SourceIP: srcIp,
|
||||||
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
DestIP: dstIp,
|
||||||
SourcePort: id.RemotePort,
|
SourcePort: id.RemotePort,
|
||||||
DestPort: id.LocalPort,
|
DestPort: id.LocalPort,
|
||||||
|
RxBytes: rxBytes,
|
||||||
|
TxBytes: txBytes,
|
||||||
|
RxPackets: rxPackets,
|
||||||
|
TxPackets: txPackets,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ep != nil {
|
if typ == nftypes.TypeStart {
|
||||||
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
if ruleId, ok := f.getRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort); ok {
|
||||||
// fields are flipped since this is the in conn
|
fields.RuleID = ruleId
|
||||||
// TODO: get bytes
|
|
||||||
fields.RxPackets = tcpStats.PacketsSent.Value()
|
|
||||||
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
|
||||||
}
|
}
|
||||||
|
} else {
|
||||||
|
f.DeleteRuleID(srcIp, dstIp, id.RemotePort, id.LocalPort)
|
||||||
}
|
}
|
||||||
|
|
||||||
f.flowLogger.StoreEvent(fields)
|
f.flowLogger.StoreEvent(fields)
|
||||||
@@ -288,18 +313,20 @@ func (c *udpPacketConn) getIdleDuration() time.Duration {
|
|||||||
return time.Since(lastSeen)
|
return time.Since(lastSeen)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
// copy reads from src and writes to dst.
|
||||||
|
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) (int64, error) {
|
||||||
bufp := bufPool.Get().(*[]byte)
|
bufp := bufPool.Get().(*[]byte)
|
||||||
defer bufPool.Put(bufp)
|
defer bufPool.Put(bufp)
|
||||||
buffer := *bufp
|
buffer := *bufp
|
||||||
|
var totalBytes int64 = 0
|
||||||
|
|
||||||
for {
|
for {
|
||||||
if ctx.Err() != nil {
|
if ctx.Err() != nil {
|
||||||
return ctx.Err()
|
return totalBytes, ctx.Err()
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
return fmt.Errorf("set read deadline: %w", err)
|
return totalBytes, fmt.Errorf("set read deadline: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
n, err := src.Read(buffer)
|
n, err := src.Read(buffer)
|
||||||
@@ -307,14 +334,15 @@ func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bu
|
|||||||
if isTimeout(err) {
|
if isTimeout(err) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return fmt.Errorf("read from %s: %w", direction, err)
|
return totalBytes, fmt.Errorf("read from %s: %w", direction, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
_, err = dst.Write(buffer[:n])
|
nWritten, err := dst.Write(buffer[:n])
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("write to %s: %w", direction, err)
|
return totalBytes, fmt.Errorf("write to %s: %w", direction, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
totalBytes += int64(nWritten)
|
||||||
c.updateLastSeen()
|
c.updateLastSeen()
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -14,8 +14,13 @@ import (
|
|||||||
type localIPManager struct {
|
type localIPManager struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
// fixed-size high array for upper byte of a IPv4 address
|
||||||
ipv4Bitmap [1 << 16]uint32
|
ipv4Bitmap [256]*ipv4LowBitmap
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||||
|
type ipv4LowBitmap struct {
|
||||||
|
bitmap [8192]uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLocalIPManager() *localIPManager {
|
func newLocalIPManager() *localIPManager {
|
||||||
@@ -27,35 +32,61 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
if ipv4 == nil {
|
if ipv4 == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
high := uint16(ipv4[0])
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
|
||||||
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) setBitInBitmap(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
|
if !ip.Is4() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
ipv4 := ip.AsSlice()
|
||||||
|
|
||||||
|
high := uint16(ipv4[0])
|
||||||
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
|
|
||||||
|
if bitmap[high] == nil {
|
||||||
|
bitmap[high] = &ipv4LowBitmap{}
|
||||||
|
}
|
||||||
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
|
||||||
|
if _, exists := ipv4Set[ip]; !exists {
|
||||||
|
ipv4Set[ip] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ip)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
high := uint16(ip[0])
|
||||||
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
|
||||||
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
index := low / 32
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
bit := low % 32
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
|
||||||
if int(high) >= len(*newIPv4Bitmap) {
|
|
||||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
|
||||||
}
|
|
||||||
ipStr := ip.String()
|
|
||||||
if _, exists := ipv4Set[ipStr]; !exists {
|
|
||||||
ipv4Set[ipStr] = struct{}{}
|
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
|
||||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip netip.Addr, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) error {
|
||||||
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[netip.Addr]struct{}, ipv4Addresses *[]netip.Addr) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -73,7 +104,13 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
addr, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
log.Warnf("invalid IP address %s in interface %s", ip.String(), iface.Name)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(addr.Unmap(), bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
log.Debugf("process IP failed: %v", err)
|
log.Debugf("process IP failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -86,14 +123,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [1 << 16]uint32
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
ipv4Set := make(map[string]struct{})
|
ipv4Set := make(map[netip.Addr]struct{})
|
||||||
var ipv4Addresses []string
|
var ipv4Addresses []netip.Addr
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// 127.0.0.0/8
|
||||||
high := uint16(127) << 8
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
for i := uint16(0); i < 256; i++ {
|
for i := 0; i < 8192; i++ {
|
||||||
newIPv4Bitmap[high|i] = 0xffffffff
|
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||||
}
|
}
|
||||||
|
|
||||||
if iface != nil {
|
if iface != nil {
|
||||||
@@ -120,12 +157,12 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
|
if !ip.Is4() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
if ip.Is4() {
|
|
||||||
return m.checkBitmapBit(ip.AsSlice())
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -20,11 +20,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range",
|
name: "Localhost range",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -32,11 +29,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost standard address",
|
name: "Localhost standard address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -44,11 +38,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Localhost range edge",
|
name: "Localhost range edge",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -56,11 +47,8 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP matches",
|
name: "Local IP matches",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
@@ -68,23 +56,26 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "Local IP doesn't match",
|
name: "Local IP doesn't match",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("192.168.1.1"),
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("192.168.1.0"),
|
|
||||||
Mask: net.CIDRMask(24, 32),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP doesn't match - addresses 32 apart",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
name: "IPv6 address",
|
name: "IPv6 address",
|
||||||
setupAddr: wgaddr.Address{
|
setupAddr: wgaddr.Address{
|
||||||
IP: net.ParseIP("fe80::1"),
|
IP: netip.MustParseAddr("fe80::1"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
IP: net.ParseIP("fe80::"),
|
|
||||||
Mask: net.CIDRMask(64, 128),
|
|
||||||
},
|
|
||||||
},
|
},
|
||||||
testIP: netip.MustParseAddr("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
@@ -192,10 +183,8 @@ func BenchmarkIPChecks(b *testing.B) {
|
|||||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup bitmap version
|
// Setup bitmap
|
||||||
bitmapManager := &localIPManager{
|
bitmapManager := newLocalIPManager()
|
||||||
ipv4Bitmap: [1 << 16]uint32{},
|
|
||||||
}
|
|
||||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
bitmapManager.setBitmapBit(ip)
|
bitmapManager.setBitmapBit(ip)
|
||||||
}
|
}
|
||||||
@@ -248,7 +237,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
|||||||
|
|
||||||
// Create two managers - one checks WG IP first, other checks it last
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
b.Run("WG_First", func(b *testing.B) {
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
bm := newLocalIPManager()
|
||||||
bm.setBitmapBit(wgIP)
|
bm.setBitmapBit(wgIP)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
@@ -257,7 +246,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("WG_Last", func(b *testing.B) {
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
bm := newLocalIPManager()
|
||||||
// Fill with other IPs first
|
// Fill with other IPs first
|
||||||
for i := 0; i < 15; i++ {
|
for i := 0; i < 15; i++ {
|
||||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
|||||||
@@ -32,7 +32,8 @@ type RouteRule struct {
|
|||||||
id string
|
id string
|
||||||
mgmtId []byte
|
mgmtId []byte
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
destination netip.Prefix
|
dstSet firewall.Set
|
||||||
|
destinations []netip.Prefix
|
||||||
proto firewall.Protocol
|
proto firewall.Protocol
|
||||||
srcPort *firewall.Port
|
srcPort *firewall.Port
|
||||||
dstPort *firewall.Port
|
dstPort *firewall.Port
|
||||||
|
|||||||
@@ -38,11 +38,8 @@ func TestTracePacket(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -198,12 +195,12 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.forwarder.Store(&forwarder.Forwarder{})
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -222,12 +219,12 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.nativeRouter.Store(false)
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -245,7 +242,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.nativeRouter.Store(true)
|
m.nativeRouter.Store(true)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -263,7 +260,7 @@ func TestTracePacket(t *testing.T) {
|
|||||||
m.routingEnabled.Store(false)
|
m.routingEnabled.Store(false)
|
||||||
},
|
},
|
||||||
packetBuilder: func() *PacketBuilder {
|
packetBuilder: func() *PacketBuilder {
|
||||||
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
},
|
},
|
||||||
expectedStages: []PacketStage{
|
expectedStages: []PacketStage{
|
||||||
StageReceived,
|
StageReceived,
|
||||||
@@ -425,8 +422,8 @@ func TestTracePacket(t *testing.T) {
|
|||||||
|
|
||||||
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||||
"100.10.0.100 should be recognized as a local IP")
|
"100.10.0.100 should be recognized as a local IP")
|
||||||
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
|
||||||
"172.17.0.2 should not be recognized as a local IP")
|
"192.168.17.2 should not be recognized as a local IP")
|
||||||
|
|
||||||
pb := tc.packetBuilder()
|
pb := tc.packetBuilder()
|
||||||
|
|
||||||
|
|||||||
@@ -39,8 +39,12 @@ const (
|
|||||||
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
// EnvForceUserspaceRouter forces userspace routing even if native routing is available.
|
||||||
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
EnvForceUserspaceRouter = "NB_FORCE_USERSPACE_ROUTER"
|
||||||
|
|
||||||
// EnvEnableNetstackLocalForwarding enables forwarding of local traffic to the native stack when running netstack
|
// EnvEnableLocalForwarding enables forwarding of local traffic to the native stack for internal (non-NetBird) interfaces.
|
||||||
// Leaving this on by default introduces a security risk as sockets on listening on localhost only will be accessible
|
// Default off as it might be security risk because sockets listening on localhost only will become accessible.
|
||||||
|
EnvEnableLocalForwarding = "NB_ENABLE_LOCAL_FORWARDING"
|
||||||
|
|
||||||
|
// EnvEnableNetstackLocalForwarding is an alias for EnvEnableLocalForwarding.
|
||||||
|
// In netstack mode, it enables forwarding of local traffic to the native stack for all interfaces.
|
||||||
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
EnvEnableNetstackLocalForwarding = "NB_ENABLE_NETSTACK_LOCAL_FORWARDING"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -49,10 +53,10 @@ var errNatNotSupported = errors.New("nat not supported with userspace firewall")
|
|||||||
// RuleSet is a set of rules grouped by a string key
|
// RuleSet is a set of rules grouped by a string key
|
||||||
type RuleSet map[string]PeerRule
|
type RuleSet map[string]PeerRule
|
||||||
|
|
||||||
type RouteRules []RouteRule
|
type RouteRules []*RouteRule
|
||||||
|
|
||||||
func (r RouteRules) Sort() {
|
func (r RouteRules) Sort() {
|
||||||
slices.SortStableFunc(r, func(a, b RouteRule) int {
|
slices.SortStableFunc(r, func(a, b *RouteRule) int {
|
||||||
// Deny rules come first
|
// Deny rules come first
|
||||||
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
if a.action == firewall.ActionDrop && b.action != firewall.ActionDrop {
|
||||||
return -1
|
return -1
|
||||||
@@ -71,7 +75,6 @@ type Manager struct {
|
|||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[netip.Addr]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
wgIface common.IFaceMapper
|
wgIface common.IFaceMapper
|
||||||
nativeFirewall firewall.Manager
|
nativeFirewall firewall.Manager
|
||||||
@@ -99,6 +102,8 @@ type Manager struct {
|
|||||||
forwarder atomic.Pointer[forwarder.Forwarder]
|
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
flowLogger nftypes.FlowLogger
|
flowLogger nftypes.FlowLogger
|
||||||
|
|
||||||
|
blockRule firewall.Rule
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -146,6 +151,11 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
log.Warnf("failed to parse %s: %v", EnvEnableNetstackLocalForwarding, err)
|
||||||
}
|
}
|
||||||
|
} else if val := os.Getenv(EnvEnableLocalForwarding); val != "" {
|
||||||
|
enableLocalForwarding, err = strconv.ParseBool(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse %s: %v", EnvEnableLocalForwarding, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
@@ -201,41 +211,35 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.blockInvalidRouted(iface); err != nil {
|
|
||||||
log.Errorf("failed to block invalid routed traffic: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := iface.SetFilter(m); err != nil {
|
if err := iface.SetFilter(m); err != nil {
|
||||||
return nil, fmt.Errorf("set filter: %w", err)
|
return nil, fmt.Errorf("set filter: %w", err)
|
||||||
}
|
}
|
||||||
return m, nil
|
return m, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) (firewall.Rule, error) {
|
||||||
if m.forwarder.Load() == nil {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("parse wireguard network: %w", err)
|
return nil, fmt.Errorf("parse wireguard network: %w", err)
|
||||||
}
|
}
|
||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
if _, err := m.AddRouteFiltering(
|
rule, err := m.addRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
wgPrefix,
|
firewall.Network{Prefix: wgPrefix},
|
||||||
firewall.ProtocolALL,
|
firewall.ProtocolALL,
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionDrop,
|
firewall.ActionDrop,
|
||||||
); err != nil {
|
)
|
||||||
return fmt.Errorf("block wg nte : %w", err)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("block wg nte : %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TODO: Block networks that we're a client of
|
// TODO: Block networks that we're a client of
|
||||||
|
|
||||||
return nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) determineRouting() error {
|
func (m *Manager) determineRouting() error {
|
||||||
@@ -273,7 +277,7 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
case !m.netstack && m.nativeFirewall != nil && m.nativeFirewall.IsServerRouteSupported():
|
case !m.netstack && m.nativeFirewall != nil:
|
||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
@@ -330,6 +334,10 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *Manager) IsStateful() bool {
|
||||||
|
return m.stateful
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
@@ -413,10 +421,23 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
id []byte,
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.addRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination firewall.Network,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort, dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
@@ -429,31 +450,36 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
id: ruleID,
|
id: ruleID,
|
||||||
mgmtId: id,
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
destination: destination,
|
dstSet: destination.Set,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
srcPort: sPort,
|
srcPort: sPort,
|
||||||
dstPort: dPort,
|
dstPort: dPort,
|
||||||
action: action,
|
action: action,
|
||||||
}
|
}
|
||||||
|
if destination.IsPrefix() {
|
||||||
|
rule.destinations = []netip.Prefix{destination.Prefix}
|
||||||
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.routeRules = append(m.routeRules, &rule)
|
||||||
m.routeRules = append(m.routeRules, rule)
|
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
m.mutex.Unlock()
|
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.deleteRouteRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) deleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
ruleID := rule.ID()
|
ruleID := rule.ID()
|
||||||
idx := slices.IndexFunc(m.routeRules, func(r RouteRule) bool {
|
idx := slices.IndexFunc(m.routeRules, func(r *RouteRule) bool {
|
||||||
return r.id == ruleID
|
return r.id == ruleID
|
||||||
})
|
})
|
||||||
if idx < 0 {
|
if idx < 0 {
|
||||||
@@ -509,6 +535,52 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return m.nativeFirewall.DeleteDNATRule(rule)
|
return m.nativeFirewall.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSet updates the rule destinations associated with the given set
|
||||||
|
// by merging the existing prefixes with the new ones, then deduplicating.
|
||||||
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
|
return m.nativeFirewall.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
var matches []*RouteRule
|
||||||
|
for _, rule := range m.routeRules {
|
||||||
|
if rule.dstSet == set {
|
||||||
|
matches = append(matches, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(matches) == 0 {
|
||||||
|
return fmt.Errorf("no route rule found for set: %s", set)
|
||||||
|
}
|
||||||
|
|
||||||
|
destinations := matches[0].destinations
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
if prefix.Addr().Is4() {
|
||||||
|
destinations = append(destinations, prefix)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
slices.SortFunc(destinations, func(a, b netip.Prefix) int {
|
||||||
|
cmp := a.Addr().Compare(b.Addr())
|
||||||
|
if cmp != 0 {
|
||||||
|
return cmp
|
||||||
|
}
|
||||||
|
return a.Bits() - b.Bits()
|
||||||
|
})
|
||||||
|
|
||||||
|
destinations = slices.Compact(destinations)
|
||||||
|
|
||||||
|
for _, rule := range matches {
|
||||||
|
rule.destinations = destinations
|
||||||
|
}
|
||||||
|
log.Debugf("updated set %s to prefixes %v", set.HashedName(), destinations)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||||
return m.processOutgoingHooks(packetData, size)
|
return m.processOutgoingHooks(packetData, size)
|
||||||
@@ -546,9 +618,8 @@ func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.stateful {
|
// for netflow we keep track even if the firewall is stateless
|
||||||
m.trackOutbound(d, srcIP, dstIP, size)
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
@@ -658,7 +729,8 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
|||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if !m.isValidPacket(d, packetData) {
|
valid, fragment := m.isValidPacket(d, packetData)
|
||||||
|
if !valid {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -668,6 +740,13 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: pass fragments of routed packets to forwarder
|
||||||
|
if fragment {
|
||||||
|
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||||
|
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
@@ -678,7 +757,7 @@ func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
|||||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleLocalTraffic handles local traffic.
|
// handleLocalTraffic handles local traffic.
|
||||||
@@ -709,9 +788,10 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// if running in netstack mode we need to pass this to the forwarder
|
// If requested we pass local traffic to internal interfaces to the forwarder.
|
||||||
if m.netstack && m.localForwarding {
|
// netstack doesn't have an interface to forward packets to the native stack so we always need to use the forwarder.
|
||||||
return m.handleNetstackLocalTraffic(packetData)
|
if m.localForwarding && (m.netstack || dstIP != m.wgIface.Address().IP) {
|
||||||
|
return m.handleForwardedLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
// track inbound packets to get the correct direction and session id for flows
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
@@ -721,8 +801,7 @@ func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packet
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
func (m *Manager) handleForwardedLocalTraffic(packetData []byte) bool {
|
||||||
|
|
||||||
fwd := m.forwarder.Load()
|
fwd := m.forwarder.Load()
|
||||||
if fwd == nil {
|
if fwd == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
@@ -739,7 +818,7 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleRoutedTraffic handles routed traffic.
|
// handleRoutedTraffic handles routed traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte) bool {
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||||
// Drop if routing is disabled
|
// Drop if routing is disabled
|
||||||
if !m.routingEnabled.Load() {
|
if !m.routingEnabled.Load() {
|
||||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
@@ -749,13 +828,15 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
|||||||
|
|
||||||
// Pass to native stack if native router is enabled or forced
|
// Pass to native stack if native router is enabled or forced
|
||||||
if m.nativeRouter.Load() {
|
if m.nativeRouter.Load() {
|
||||||
|
m.trackInbound(d, srcIP, dstIP, nil, size)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
proto, pnum := getProtocolFromPacket(d)
|
proto, pnum := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
if !pass {
|
||||||
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
@@ -770,6 +851,8 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
|||||||
SourcePort: srcPort,
|
SourcePort: srcPort,
|
||||||
DestPort: dstPort,
|
DestPort: dstPort,
|
||||||
// TODO: icmp type/code
|
// TODO: icmp type/code
|
||||||
|
RxPackets: 1,
|
||||||
|
RxBytes: uint64(size),
|
||||||
})
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
@@ -779,8 +862,11 @@ func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packe
|
|||||||
if fwd == nil {
|
if fwd == nil {
|
||||||
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
||||||
} else {
|
} else {
|
||||||
|
fwd.RegisterRuleID(srcIP, dstIP, srcPort, dstPort, ruleID)
|
||||||
|
|
||||||
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject routed packet: %v", err)
|
m.logger.Error("Failed to inject routed packet: %v", err)
|
||||||
|
fwd.DeleteRuleID(srcIP, dstIP, srcPort, dstPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -812,17 +898,32 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
// isValidPacket checks if the packet is valid.
|
||||||
|
// It returns true, false if the packet is valid and not a fragment.
|
||||||
|
// It returns true, true if the packet is a fragment and valid.
|
||||||
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||||
return false
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
l := len(d.decoded)
|
||||||
m.logger.Trace("packet doesn't have network and transport layers")
|
|
||||||
return false
|
// L3 and L4 are mandatory
|
||||||
|
if l >= 2 {
|
||||||
|
return true, false
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
|
// Fragments are also valid
|
||||||
|
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||||
|
ip4 := d.ip4
|
||||||
|
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("packet doesn't have network and transport layers")
|
||||||
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||||
@@ -962,8 +1063,15 @@ func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol
|
|||||||
return nil, false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) ruleMatches(rule *RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
if !rule.destination.Contains(dstAddr) {
|
destMatched := false
|
||||||
|
for _, dst := range rule.destinations {
|
||||||
|
if dst.Contains(dstAddr) {
|
||||||
|
destMatched = true
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !destMatched {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -991,11 +1099,6 @@ func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
func (m *Manager) SetNetwork(network *net.IPNet) {
|
|
||||||
m.wgNetwork = network
|
|
||||||
}
|
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
@@ -1065,7 +1168,22 @@ func (m *Manager) EnableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.determineRouting()
|
if err := m.determineRouting(); err != nil {
|
||||||
|
return fmt.Errorf("determine routing: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.forwarder.Load() == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := m.blockInvalidRouted(m.wgIface)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("block invalid routed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
m.blockRule = rule
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DisableRouting() error {
|
func (m *Manager) DisableRouting() error {
|
||||||
@@ -1090,5 +1208,12 @@ func (m *Manager) DisableRouting() error {
|
|||||||
|
|
||||||
log.Debug("forwarder stopped")
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
if m.blockRule != nil {
|
||||||
|
if err := m.deleteRouteRule(m.blockRule); err != nil {
|
||||||
|
return fmt.Errorf("delete block rule: %w", err)
|
||||||
|
}
|
||||||
|
m.blockRule = nil
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -174,11 +174,6 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Apply scenario-specific setup
|
// Apply scenario-specific setup
|
||||||
sc.setupFunc(manager)
|
sc.setupFunc(manager)
|
||||||
|
|
||||||
@@ -219,11 +214,6 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
// Pre-populate connection table
|
// Pre-populate connection table
|
||||||
srcIPs := generateRandomIPs(count)
|
srcIPs := generateRandomIPs(count)
|
||||||
dstIPs := generateRandomIPs(count)
|
dstIPs := generateRandomIPs(count)
|
||||||
@@ -267,11 +257,6 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
srcIP := generateRandomIPs(1)[0]
|
srcIP := generateRandomIPs(1)[0]
|
||||||
dstIP := generateRandomIPs(1)[0]
|
dstIP := generateRandomIPs(1)[0]
|
||||||
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
outbound := generatePacket(b, srcIP, dstIP, 1024, 80, layers.IPProtocolTCP)
|
||||||
@@ -304,10 +289,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -321,10 +302,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -339,10 +316,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -356,10 +329,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
}
|
|
||||||
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
b.Setenv("NB_DISABLE_CONNTRACK", "1")
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -373,10 +342,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -390,10 +355,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -408,10 +369,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolTCP,
|
proto: layers.IPProtocolTCP,
|
||||||
state: "post_handshake",
|
state: "post_handshake",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -426,10 +383,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "new",
|
state: "new",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -443,10 +396,6 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
proto: layers.IPProtocolUDP,
|
proto: layers.IPProtocolUDP,
|
||||||
state: "established",
|
state: "established",
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("0.0.0.0"),
|
|
||||||
Mask: net.CIDRMask(0, 32),
|
|
||||||
}
|
|
||||||
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
require.NoError(b, os.Unsetenv("NB_DISABLE_CONNTRACK"))
|
||||||
},
|
},
|
||||||
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
genPackets: func(srcIP, dstIP net.IP) ([]byte, []byte) {
|
||||||
@@ -593,11 +542,6 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -681,11 +625,6 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
@@ -797,11 +736,6 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
@@ -882,11 +816,6 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.SetNetwork(&net.IPNet{
|
|
||||||
IP: net.ParseIP("100.64.0.0"),
|
|
||||||
Mask: net.CIDRMask(10, 32),
|
|
||||||
})
|
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
@@ -1032,7 +961,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
dst := fw.Network{Prefix: r.dest}
|
||||||
|
_, err := manager.AddRouteFiltering(nil, r.sources, dst, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,15 +15,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
localIP := net.ParseIP("100.10.0.100")
|
localIP := netip.MustParseAddr("100.10.0.100")
|
||||||
wgNet := &net.IPNet{
|
wgNet := netip.MustParsePrefix("100.10.0.0/16")
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
@@ -42,8 +39,6 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
|
|
||||||
manager.wgNetwork = wgNet
|
|
||||||
|
|
||||||
err = manager.UpdateLocalIPs()
|
err = manager.UpdateLocalIPs()
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
@@ -188,6 +183,281 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
ruleAction: fw.ActionAccept,
|
ruleAction: fw.ActionAccept,
|
||||||
shouldBeBlocked: true,
|
shouldBeBlocked: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow UDP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP packet doesn't match UDP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP packet doesn't match TCP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match TCP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match UDP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic within port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block TCP traffic outside port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Edge Case - Port at Range Boundary",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8100,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP Port Range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 5060,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow multiple destination ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow multiple source ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
// New drop test cases
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop UDP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop ICMP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolICMP,
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop all traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolALL,
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop traffic from multiple source ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop multiple destination ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic within port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Accept TCP traffic outside drop port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic with source port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 32100,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed rule - drop specific port but allow other ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
@@ -198,6 +468,28 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
|
// add general accept rule to test drop rule
|
||||||
|
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
|
||||||
|
rules, err := manager.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
fw.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, rules)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddPeerFiltering(
|
||||||
nil,
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
net.ParseIP(tc.ruleIP),
|
||||||
@@ -283,14 +575,13 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
dev := mocks.NewMockDevice(ctrl)
|
dev := mocks.NewMockDevice(ctrl)
|
||||||
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
dev.EXPECT().MTU().Return(1500, nil).AnyTimes()
|
||||||
|
|
||||||
localIP, wgNet, err := net.ParseCIDR(network)
|
wgNet := netip.MustParsePrefix(network)
|
||||||
require.NoError(tb, err)
|
|
||||||
|
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: localIP,
|
IP: wgNet.Addr(),
|
||||||
Network: wgNet,
|
Network: wgNet,
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -303,8 +594,8 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false, flowLogger)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled.Load())
|
require.True(tb, manager.routingEnabled.Load())
|
||||||
require.False(tb, manager.nativeRouter.Load())
|
require.False(tb, manager.nativeRouter.Load())
|
||||||
@@ -321,7 +612,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
type rule struct {
|
type rule struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -347,7 +638,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -363,7 +654,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -379,7 +670,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -395,7 +686,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolUDP,
|
proto: fw.ProtocolUDP,
|
||||||
dstPort: &fw.Port{Values: []uint16{53}},
|
dstPort: &fw.Port{Values: []uint16{53}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -409,7 +700,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
@@ -424,7 +715,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -440,7 +731,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -456,7 +747,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -472,7 +763,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -488,7 +779,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -507,7 +798,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
},
|
},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -521,7 +812,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
@@ -536,33 +827,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
shouldPass: true,
|
shouldPass: true,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Multiple source networks with mismatched protocol",
|
|
||||||
srcIP: "172.16.0.1",
|
|
||||||
dstIP: "192.168.1.100",
|
|
||||||
// Should not match TCP rule
|
|
||||||
proto: fw.ProtocolUDP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 80,
|
|
||||||
rule: rule{
|
|
||||||
sources: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
|
||||||
},
|
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
proto: fw.ProtocolTCP,
|
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
|
||||||
action: fw.ActionAccept,
|
|
||||||
},
|
|
||||||
shouldPass: false,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "Allow multiple destination ports",
|
name: "Allow multiple destination ports",
|
||||||
srcIP: "100.10.0.1",
|
srcIP: "100.10.0.1",
|
||||||
@@ -572,7 +843,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -588,7 +859,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -604,7 +875,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
@@ -621,7 +892,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -640,7 +911,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 7999,
|
dstPort: 7999,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -659,7 +930,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{
|
srcPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -678,7 +949,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{
|
srcPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -700,7 +971,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8100,
|
dstPort: 8100,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -719,7 +990,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 5060,
|
dstPort: 5060,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolUDP,
|
proto: fw.ProtocolUDP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -738,7 +1009,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -757,7 +1028,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -773,7 +1044,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
},
|
},
|
||||||
@@ -791,17 +1062,158 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
},
|
},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
},
|
},
|
||||||
shouldPass: false,
|
shouldPass: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "Drop empty destination set",
|
||||||
|
srcIP: "172.16.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
},
|
||||||
|
dest: fw.Network{Set: fw.Set{}},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Accept TCP traffic outside drop port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
action: fw.ActionDrop,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow UDP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP packet doesn't match UDP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP packet doesn't match TCP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match TCP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match UDP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if tc.rule.action == fw.ActionDrop {
|
||||||
|
// add general accept rule to test drop rule
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
nil,
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
@@ -836,7 +1248,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
rules []struct {
|
rules []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -857,7 +1269,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name: "Drop rules take precedence over accept",
|
name: "Drop rules take precedence over accept",
|
||||||
rules: []struct {
|
rules: []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -866,7 +1278,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Accept rule added first
|
// Accept rule added first
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -874,7 +1286,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Drop rule added second but should be evaluated first
|
// Drop rule added second but should be evaluated first
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -912,7 +1324,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name: "Multiple drop rules take precedence",
|
name: "Multiple drop rules take precedence",
|
||||||
rules: []struct {
|
rules: []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -921,14 +1333,14 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Accept all
|
// Accept all
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Drop specific port
|
// Drop specific port
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -936,7 +1348,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Drop different port
|
// Drop different port
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -1015,3 +1427,50 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteACLSet(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
// Add rule that uses the set (initially empty)
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
|
||||||
|
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||||
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now the packet should be allowed
|
||||||
|
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||||
|
}
|
||||||
|
|||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
@@ -270,11 +271,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
AddressFunc: func() wgaddr.Address {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return wgaddr.Address{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.10.0.100"),
|
IP: netip.MustParseAddr("100.10.0.100"),
|
||||||
Network: &net.IPNet{
|
Network: netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -284,10 +282,6 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
m.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
@@ -395,10 +389,6 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
@@ -508,11 +498,6 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}, false, flowLogger)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
|
||||||
IP: net.ParseIP("100.10.0.0"),
|
|
||||||
Mask: net.CIDRMask(16, 32),
|
|
||||||
}
|
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
@@ -711,3 +696,203 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateSetMerge(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
initialPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
// Update the set with initial prefixes
|
||||||
|
err = manager.UpdateSet(set, initialPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test initial prefixes work
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
dstIP1 := netip.MustParseAddr("10.0.0.100")
|
||||||
|
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||||
|
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||||
|
|
||||||
|
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||||
|
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||||
|
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
|
||||||
|
|
||||||
|
newPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, newPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check that all original prefixes are still included
|
||||||
|
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||||
|
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||||
|
|
||||||
|
// Check that new prefixes are included
|
||||||
|
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||||
|
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||||
|
|
||||||
|
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||||
|
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||||
|
|
||||||
|
// Verify the rule has all prefixes
|
||||||
|
manager.mutex.RLock()
|
||||||
|
foundRule := false
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
foundRule = true
|
||||||
|
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
|
||||||
|
"Rule should have all prefixes merged")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
require.True(t, foundRule, "Rule should be found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateSetDeduplication(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
initialPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, initialPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check the internal state for deduplication
|
||||||
|
manager.mutex.RLock()
|
||||||
|
foundRule := false
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
foundRule = true
|
||||||
|
// Should have deduplicated to 2 prefixes
|
||||||
|
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
|
||||||
|
|
||||||
|
// Check the prefixes are correct
|
||||||
|
expectedPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
for i, prefix := range expectedPrefixes {
|
||||||
|
require.True(t, r.destinations[i] == prefix,
|
||||||
|
"Prefix should match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
require.True(t, foundRule, "Rule should be found")
|
||||||
|
|
||||||
|
// Test with overlapping prefixes of different sizes
|
||||||
|
overlappingPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/16"), // More general
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"), // More general
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, overlappingPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check that all prefixes are included (no deduplication of overlapping prefixes)
|
||||||
|
manager.mutex.RLock()
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
// Should have all 4 prefixes (2 original + 2 new more general ones)
|
||||||
|
require.Len(t, r.destinations, 4,
|
||||||
|
"Overlapping prefixes should not be deduplicated")
|
||||||
|
|
||||||
|
// Verify they're sorted correctly (more specific prefixes should come first)
|
||||||
|
prefixes := make([]string, 0, len(r.destinations))
|
||||||
|
for _, p := range r.destinations {
|
||||||
|
prefixes = append(prefixes, p.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check sorted order
|
||||||
|
require.Equal(t, []string{
|
||||||
|
"10.0.0.0/16",
|
||||||
|
"10.0.0.0/24",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"192.168.1.0/24",
|
||||||
|
}, prefixes, "Prefixes should be sorted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Test functionality with all prefixes
|
||||||
|
testCases := []struct {
|
||||||
|
dstIP netip.Addr
|
||||||
|
expected bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
|
||||||
|
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
|
||||||
|
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
|
||||||
|
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
|
||||||
|
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
for _, tc := range testCases {
|
||||||
|
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
|
|||||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||||
if params.Logger == nil {
|
if params.Logger == nil {
|
||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = getLogger()
|
||||||
}
|
}
|
||||||
|
|
||||||
mux := &UDPMuxDefault{
|
mux := &UDPMuxDefault{
|
||||||
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
|
|||||||
buf: make([]byte, size),
|
buf: make([]byte, size),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLogger() logging.LeveledLogger {
|
||||||
|
fac := logging.NewDefaultLoggerFactory()
|
||||||
|
//fac.Writer = log.StandardLogger().Writer()
|
||||||
|
return fac.NewLogger("ice")
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
|
|||||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||||
if params.Logger == nil {
|
if params.Logger == nil {
|
||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = getLogger()
|
||||||
}
|
}
|
||||||
if params.XORMappedAddrCacheTTL == 0 {
|
if params.XORMappedAddrCacheTTL == 0 {
|
||||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||||
@@ -164,7 +164,7 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
if u.address.Network.Contains(a.AsSlice()) {
|
if u.address.Network.Contains(a) {
|
||||||
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var zeroKey wgtypes.Key
|
||||||
|
|
||||||
type KernelConfigurer struct {
|
type KernelConfigurer struct {
|
||||||
deviceName string
|
deviceName string
|
||||||
}
|
}
|
||||||
@@ -201,14 +203,71 @@ func (c *KernelConfigurer) configure(config wgtypes.Config) error {
|
|||||||
func (c *KernelConfigurer) Close() {
|
func (c *KernelConfigurer) Close() {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (c *KernelConfigurer) FullStats() (*Stats, error) {
|
||||||
peer, err := c.getPeer(c.deviceName, peerKey)
|
wg, err := wgctrl.New()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("get wireguard stats: %w", err)
|
return nil, fmt.Errorf("wgctl: %w", err)
|
||||||
}
|
}
|
||||||
return WGStats{
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Got error while closing wgctl: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||||
|
}
|
||||||
|
fullStats := &Stats{
|
||||||
|
DeviceName: wgDevice.Name,
|
||||||
|
PublicKey: wgDevice.PublicKey.String(),
|
||||||
|
ListenPort: wgDevice.ListenPort,
|
||||||
|
FWMark: wgDevice.FirewallMark,
|
||||||
|
Peers: []Peer{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, p := range wgDevice.Peers {
|
||||||
|
peer := Peer{
|
||||||
|
PublicKey: p.PublicKey.String(),
|
||||||
|
AllowedIPs: p.AllowedIPs,
|
||||||
|
TxBytes: p.TransmitBytes,
|
||||||
|
RxBytes: p.ReceiveBytes,
|
||||||
|
LastHandshake: p.LastHandshakeTime,
|
||||||
|
PresharedKey: p.PresharedKey != zeroKey,
|
||||||
|
}
|
||||||
|
if p.Endpoint != nil {
|
||||||
|
peer.Endpoint = *p.Endpoint
|
||||||
|
}
|
||||||
|
fullStats.Peers = append(fullStats.Peers, peer)
|
||||||
|
}
|
||||||
|
return fullStats, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *KernelConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
|
stats := make(map[string]WGStats)
|
||||||
|
wg, err := wgctrl.New()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("wgctl: %w", err)
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
err = wg.Close()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Got error while closing wgctl: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
wgDevice, err := wg.Device(c.deviceName)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get device %s: %w", c.deviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range wgDevice.Peers {
|
||||||
|
stats[peer.PublicKey.String()] = WGStats{
|
||||||
LastHandshake: peer.LastHandshakeTime,
|
LastHandshake: peer.LastHandshakeTime,
|
||||||
TxBytes: peer.TransmitBytes,
|
TxBytes: peer.TransmitBytes,
|
||||||
RxBytes: peer.ReceiveBytes,
|
RxBytes: peer.ReceiveBytes,
|
||||||
}, nil
|
}
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"encoding/base64"
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
@@ -17,6 +18,20 @@ import (
|
|||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
privateKey = "private_key"
|
||||||
|
ipcKeyLastHandshakeTimeSec = "last_handshake_time_sec"
|
||||||
|
ipcKeyLastHandshakeTimeNsec = "last_handshake_time_nsec"
|
||||||
|
ipcKeyTxBytes = "tx_bytes"
|
||||||
|
ipcKeyRxBytes = "rx_bytes"
|
||||||
|
allowedIP = "allowed_ip"
|
||||||
|
endpoint = "endpoint"
|
||||||
|
fwmark = "fwmark"
|
||||||
|
listenPort = "listen_port"
|
||||||
|
publicKey = "public_key"
|
||||||
|
presharedKey = "preshared_key"
|
||||||
|
)
|
||||||
|
|
||||||
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
var ErrAllowedIPNotFound = fmt.Errorf("allowed IP not found")
|
||||||
|
|
||||||
type WGUSPConfigurer struct {
|
type WGUSPConfigurer struct {
|
||||||
@@ -178,6 +193,15 @@ func (c *WGUSPConfigurer) RemoveAllowedIP(peerKey string, ip string) error {
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *WGUSPConfigurer) FullStats() (*Stats, error) {
|
||||||
|
ipcStr, err := c.device.IpcGet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("IpcGet failed: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return parseStatus(c.deviceName, ipcStr)
|
||||||
|
}
|
||||||
|
|
||||||
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
// startUAPI starts the UAPI listener for managing the WireGuard interface via external tool
|
||||||
func (t *WGUSPConfigurer) startUAPI() {
|
func (t *WGUSPConfigurer) startUAPI() {
|
||||||
var err error
|
var err error
|
||||||
@@ -217,91 +241,75 @@ func (t *WGUSPConfigurer) Close() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGUSPConfigurer) GetStats(peerKey string) (WGStats, error) {
|
func (t *WGUSPConfigurer) GetStats() (map[string]WGStats, error) {
|
||||||
ipc, err := t.device.IpcGet()
|
ipc, err := t.device.IpcGet()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return WGStats{}, fmt.Errorf("ipc get: %w", err)
|
return nil, fmt.Errorf("ipc get: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
stats, err := findPeerInfo(ipc, peerKey, []string{
|
return parseTransfers(ipc)
|
||||||
"last_handshake_time_sec",
|
|
||||||
"last_handshake_time_nsec",
|
|
||||||
"tx_bytes",
|
|
||||||
"rx_bytes",
|
|
||||||
})
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("find peer info: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
sec, err := strconv.ParseInt(stats["last_handshake_time_sec"], 10, 64)
|
func parseTransfers(ipc string) (map[string]WGStats, error) {
|
||||||
if err != nil {
|
stats := make(map[string]WGStats)
|
||||||
return WGStats{}, fmt.Errorf("parse handshake sec: %w", err)
|
var (
|
||||||
}
|
currentKey string
|
||||||
nsec, err := strconv.ParseInt(stats["last_handshake_time_nsec"], 10, 64)
|
currentStats WGStats
|
||||||
if err != nil {
|
hasPeer bool
|
||||||
return WGStats{}, fmt.Errorf("parse handshake nsec: %w", err)
|
)
|
||||||
}
|
lines := strings.Split(ipc, "\n")
|
||||||
txBytes, err := strconv.ParseInt(stats["tx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse tx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
rxBytes, err := strconv.ParseInt(stats["rx_bytes"], 10, 64)
|
|
||||||
if err != nil {
|
|
||||||
return WGStats{}, fmt.Errorf("parse rx_bytes: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return WGStats{
|
|
||||||
LastHandshake: time.Unix(sec, nsec),
|
|
||||||
TxBytes: txBytes,
|
|
||||||
RxBytes: rxBytes,
|
|
||||||
}, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func findPeerInfo(ipcInput string, peerKey string, searchConfigKeys []string) (map[string]string, error) {
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("parse key: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
hexKey := hex.EncodeToString(peerKeyParsed[:])
|
|
||||||
|
|
||||||
lines := strings.Split(ipcInput, "\n")
|
|
||||||
|
|
||||||
configFound := map[string]string{}
|
|
||||||
foundPeer := false
|
|
||||||
for _, line := range lines {
|
for _, line := range lines {
|
||||||
line = strings.TrimSpace(line)
|
line = strings.TrimSpace(line)
|
||||||
|
|
||||||
// If we're within the details of the found peer and encounter another public key,
|
// If we're within the details of the found peer and encounter another public key,
|
||||||
// this means we're starting another peer's details. So, stop.
|
// this means we're starting another peer's details. So, stop.
|
||||||
if strings.HasPrefix(line, "public_key=") && foundPeer {
|
if strings.HasPrefix(line, "public_key=") {
|
||||||
break
|
peerID := strings.TrimPrefix(line, "public_key=")
|
||||||
|
h, err := hex.DecodeString(peerID)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("decode peerID: %w", err)
|
||||||
|
}
|
||||||
|
currentKey = base64.StdEncoding.EncodeToString(h)
|
||||||
|
currentStats = WGStats{} // Reset stats for the new peer
|
||||||
|
hasPeer = true
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Identify the peer with the specific public key
|
if !hasPeer {
|
||||||
if line == fmt.Sprintf("public_key=%s", hexKey) {
|
continue
|
||||||
foundPeer = true
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, key := range searchConfigKeys {
|
key := strings.SplitN(line, "=", 2)
|
||||||
if foundPeer && strings.HasPrefix(line, key+"=") {
|
if len(key) != 2 {
|
||||||
v := strings.SplitN(line, "=", 2)
|
continue
|
||||||
configFound[v[0]] = v[1]
|
|
||||||
}
|
}
|
||||||
|
switch key[0] {
|
||||||
|
case ipcKeyLastHandshakeTimeSec:
|
||||||
|
hs, err := toLastHandshake(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
currentStats.LastHandshake = hs
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyRxBytes:
|
||||||
|
rxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse rx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.RxBytes = rxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
|
case ipcKeyTxBytes:
|
||||||
|
TxBytes, err := toBytes(key[1])
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse tx_bytes: %w", err)
|
||||||
|
}
|
||||||
|
currentStats.TxBytes = TxBytes
|
||||||
|
stats[currentKey] = currentStats
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo: use multierr
|
return stats, nil
|
||||||
for _, key := range searchConfigKeys {
|
|
||||||
if _, ok := configFound[key]; !ok {
|
|
||||||
return configFound, fmt.Errorf("config key not found: %s", key)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
if !foundPeer {
|
|
||||||
return nil, fmt.Errorf("%w: %s", ErrPeerNotFound, peerKey)
|
|
||||||
}
|
|
||||||
|
|
||||||
return configFound, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
||||||
@@ -355,9 +363,154 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
return sb.String()
|
return sb.String()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toLastHandshake(stringVar string) (time.Time, error) {
|
||||||
|
sec, err := strconv.ParseInt(stringVar, 10, 64)
|
||||||
|
if err != nil {
|
||||||
|
return time.Time{}, fmt.Errorf("parse handshake sec: %w", err)
|
||||||
|
}
|
||||||
|
return time.Unix(sec, 0), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func toBytes(s string) (int64, error) {
|
||||||
|
return strconv.ParseInt(s, 10, 64)
|
||||||
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func hexToWireguardKey(hexKey string) (wgtypes.Key, error) {
|
||||||
|
// Decode hex string to bytes
|
||||||
|
keyBytes, err := hex.DecodeString(hexKey)
|
||||||
|
if err != nil {
|
||||||
|
return wgtypes.Key{}, fmt.Errorf("failed to decode hex key: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check if we have the right number of bytes (WireGuard keys are 32 bytes)
|
||||||
|
if len(keyBytes) != 32 {
|
||||||
|
return wgtypes.Key{}, fmt.Errorf("invalid key length: expected 32 bytes, got %d", len(keyBytes))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Convert to wgtypes.Key
|
||||||
|
var key wgtypes.Key
|
||||||
|
copy(key[:], keyBytes)
|
||||||
|
|
||||||
|
return key, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func parseStatus(deviceName, ipcStr string) (*Stats, error) {
|
||||||
|
stats := &Stats{DeviceName: deviceName}
|
||||||
|
var currentPeer *Peer
|
||||||
|
for _, line := range strings.Split(strings.TrimSpace(ipcStr), "\n") {
|
||||||
|
if line == "" {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
parts := strings.SplitN(line, "=", 2)
|
||||||
|
if len(parts) != 2 {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
key := parts[0]
|
||||||
|
val := parts[1]
|
||||||
|
|
||||||
|
switch key {
|
||||||
|
case privateKey:
|
||||||
|
key, err := hexToWireguardKey(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse private key: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
stats.PublicKey = key.PublicKey().String()
|
||||||
|
case publicKey:
|
||||||
|
// Save previous peer
|
||||||
|
if currentPeer != nil {
|
||||||
|
stats.Peers = append(stats.Peers, *currentPeer)
|
||||||
|
}
|
||||||
|
key, err := hexToWireguardKey(val)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse public key: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer = &Peer{
|
||||||
|
PublicKey: key.String(),
|
||||||
|
}
|
||||||
|
case listenPort:
|
||||||
|
if port, err := strconv.Atoi(val); err == nil {
|
||||||
|
stats.ListenPort = port
|
||||||
|
}
|
||||||
|
case fwmark:
|
||||||
|
if fwmark, err := strconv.Atoi(val); err == nil {
|
||||||
|
stats.FWMark = fwmark
|
||||||
|
}
|
||||||
|
case endpoint:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
host, portStr, err := net.SplitHostPort(strings.Trim(val, "[]"))
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse endpoint: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
port, err := strconv.Atoi(portStr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("failed to parse endpoint port: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.Endpoint = net.UDPAddr{
|
||||||
|
IP: net.ParseIP(host),
|
||||||
|
Port: port,
|
||||||
|
}
|
||||||
|
case allowedIP:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
_, ipnet, err := net.ParseCIDR(val)
|
||||||
|
if err == nil {
|
||||||
|
currentPeer.AllowedIPs = append(currentPeer.AllowedIPs, *ipnet)
|
||||||
|
}
|
||||||
|
case ipcKeyTxBytes:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rxBytes, err := toBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.TxBytes = rxBytes
|
||||||
|
case ipcKeyRxBytes:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
rxBytes, err := toBytes(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.RxBytes = rxBytes
|
||||||
|
|
||||||
|
case ipcKeyLastHandshakeTimeSec:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
ts, err := toLastHandshake(val)
|
||||||
|
if err != nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
currentPeer.LastHandshake = ts
|
||||||
|
case presharedKey:
|
||||||
|
if currentPeer == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if val != "" {
|
||||||
|
currentPeer.PresharedKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if currentPeer != nil {
|
||||||
|
stats.Peers = append(stats.Peers, *currentPeer)
|
||||||
|
}
|
||||||
|
return stats, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,10 +2,8 @@ package configurer
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/hex"
|
"encoding/hex"
|
||||||
"fmt"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
"golang.zx2c4.com/wireguard/wgctrl/wgtypes"
|
||||||
)
|
)
|
||||||
@@ -34,58 +32,35 @@ errno=0
|
|||||||
|
|
||||||
`
|
`
|
||||||
|
|
||||||
func Test_findPeerInfo(t *testing.T) {
|
func Test_parseTransfers(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
peerKey string
|
peerKey string
|
||||||
searchKeys []string
|
want WGStats
|
||||||
want map[string]string
|
|
||||||
wantErr bool
|
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "single",
|
name: "single",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "b85996fecc9c7f1fc6d2572a76eda11d59bcd20be8e543b15ce4bd85a8e75a33",
|
||||||
searchKeys: []string{"tx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 0,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 0,
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "multiple",
|
name: "multiple",
|
||||||
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
peerKey: "58402e695ba1772b1cc9309755f043251ea77fdcf10fbe63989ceb7e19321376",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 38333,
|
||||||
"tx_bytes": "38333",
|
RxBytes: 2224,
|
||||||
"rx_bytes": "2224",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "lastpeer",
|
name: "lastpeer",
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
||||||
searchKeys: []string{"tx_bytes", "rx_bytes"},
|
want: WGStats{
|
||||||
want: map[string]string{
|
TxBytes: 1212111,
|
||||||
"tx_bytes": "1212111",
|
RxBytes: 1929999999,
|
||||||
"rx_bytes": "1929999999",
|
|
||||||
},
|
},
|
||||||
wantErr: false,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "peer not found",
|
|
||||||
peerKey: "1111111111111111111111111111111111111111111111111111111111111111",
|
|
||||||
searchKeys: nil,
|
|
||||||
want: nil,
|
|
||||||
wantErr: true,
|
|
||||||
},
|
|
||||||
{
|
|
||||||
name: "key not found",
|
|
||||||
peerKey: "662e14fd594556f522604703340351258903b64f35553763f19426ab2a515c58",
|
|
||||||
searchKeys: []string{"tx_bytes", "unknown_key"},
|
|
||||||
want: map[string]string{
|
|
||||||
"tx_bytes": "1212111",
|
|
||||||
},
|
|
||||||
wantErr: true,
|
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
@@ -96,9 +71,19 @@ func Test_findPeerInfo(t *testing.T) {
|
|||||||
key, err := wgtypes.NewKey(res)
|
key, err := wgtypes.NewKey(res)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
got, err := findPeerInfo(ipcFixture, key.String(), tt.searchKeys)
|
stats, err := parseTransfers(ipcFixture)
|
||||||
assert.Equalf(t, tt.wantErr, err != nil, fmt.Sprintf("findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys))
|
if err != nil {
|
||||||
assert.Equalf(t, tt.want, got, "findPeerInfo(%v, %v, %v)", ipcFixture, key.String(), tt.searchKeys)
|
require.NoError(t, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
stat, ok := stats[key.String()]
|
||||||
|
if !ok {
|
||||||
|
require.True(t, ok)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, tt.want, stat)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
24
client/iface/configurer/wgshow.go
Normal file
24
client/iface/configurer/wgshow.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package configurer
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Peer struct {
|
||||||
|
PublicKey string
|
||||||
|
Endpoint net.UDPAddr
|
||||||
|
AllowedIPs []net.IPNet
|
||||||
|
TxBytes int64
|
||||||
|
RxBytes int64
|
||||||
|
LastHandshake time.Time
|
||||||
|
PresharedKey bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type Stats struct {
|
||||||
|
DeviceName string
|
||||||
|
PublicKey string
|
||||||
|
ListenPort int
|
||||||
|
FWMark int
|
||||||
|
Peers []Peer
|
||||||
|
}
|
||||||
@@ -14,6 +14,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
// WGTunDevice ignore the WGTunDevice interface on Android because the creation of the tun device is different on this platform
|
||||||
@@ -43,11 +44,11 @@ func NewTunDevice(address wgaddr.Address, port int, key string, mtu int, iceBind
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains []string) (WGConfigurer, error) {
|
func (t *WGTunDevice) Create(routes []string, dns string, searchDomains domain.List) (WGConfigurer, error) {
|
||||||
log.Info("create tun interface")
|
log.Info("create tun interface")
|
||||||
|
|
||||||
routesString := routesToString(routes)
|
routesString := routesToString(routes)
|
||||||
searchDomainsToString := searchDomainsToString(searchDomains)
|
searchDomainsToString := searchDomainsToString(searchDomains.ToPunycodeList())
|
||||||
|
|
||||||
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
fd, err := t.tunAdapter.ConfigureInterface(t.address.String(), t.mtu, dns, searchDomainsToString, routesString)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package device
|
package device
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -24,9 +23,6 @@ type PacketFilter interface {
|
|||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
|
|
||||||
// SetNetwork of the wireguard interface to which filtering applied
|
|
||||||
SetNetwork(*net.IPNet)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// FilteredDevice to override Read or Write of packets
|
// FilteredDevice to override Read or Write of packets
|
||||||
|
|||||||
@@ -51,7 +51,11 @@ func (t *TunNetstackDevice) Create() (WGConfigurer, error) {
|
|||||||
log.Info("create nbnetstack tun interface")
|
log.Info("create nbnetstack tun interface")
|
||||||
|
|
||||||
// TODO: get from service listener runtime IP
|
// TODO: get from service listener runtime IP
|
||||||
dnsAddr := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
dnsAddr, err := nbnet.GetLastIPFromNetwork(t.address.Network, 1)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("last ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
log.Debugf("netstack using address: %s", t.address.IP)
|
log.Debugf("netstack using address: %s", t.address.IP)
|
||||||
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
t.nsTun = nbnetstack.NewNetStackTun(t.listenAddress, t.address.IP, dnsAddr, t.mtu)
|
||||||
log.Debugf("netstack using dns address: %s", dnsAddr)
|
log.Debugf("netstack using dns address: %s", dnsAddr)
|
||||||
|
|||||||
@@ -16,5 +16,6 @@ type WGConfigurer interface {
|
|||||||
AddAllowedIP(peerKey string, allowedIP string) error
|
AddAllowedIP(peerKey string, allowedIP string) error
|
||||||
RemoveAllowedIP(peerKey string, allowedIP string) error
|
RemoveAllowedIP(peerKey string, allowedIP string) error
|
||||||
Close()
|
Close()
|
||||||
GetStats(peerKey string) (configurer.WGStats, error)
|
GetStats() (map[string]configurer.WGStats, error)
|
||||||
|
FullStats() (*configurer.Stats, error)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,7 +64,15 @@ func (l *wgLink) assignAddr(address wgaddr.Address) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
ip := address.IP.String()
|
ip := address.IP.String()
|
||||||
mask := "0x" + address.Network.Mask.String()
|
|
||||||
|
// Convert prefix length to hex netmask
|
||||||
|
prefixLen := address.Network.Bits()
|
||||||
|
if !address.IP.Is4() {
|
||||||
|
return fmt.Errorf("IPv6 not supported for interface assignment")
|
||||||
|
}
|
||||||
|
|
||||||
|
maskBits := uint32(0xffffffff) << (32 - prefixLen)
|
||||||
|
mask := fmt.Sprintf("0x%08x", maskBits)
|
||||||
|
|
||||||
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
log.Infof("assign addr %s mask %s to %s interface", ip, mask, l.name)
|
||||||
|
|
||||||
|
|||||||
@@ -8,10 +8,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
type WGTunDevice interface {
|
type WGTunDevice interface {
|
||||||
Create(routes []string, dns string, searchDomains []string) (device.WGConfigurer, error)
|
Create(routes []string, dns string, searchDomains domain.List) (device.WGConfigurer, error)
|
||||||
Up() (*bind.UniversalUDPMuxDefault, error)
|
Up() (*bind.UniversalUDPMuxDefault, error)
|
||||||
UpdateAddr(address wgaddr.Address) error
|
UpdateAddr(address wgaddr.Address) error
|
||||||
WgAddress() wgaddr.Address
|
WgAddress() wgaddr.Address
|
||||||
|
|||||||
@@ -185,7 +185,6 @@ func (w *WGIface) SetFilter(filter device.PacketFilter) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
w.filter = filter
|
w.filter = filter
|
||||||
w.filter.SetNetwork(w.tun.WgAddress().Network)
|
|
||||||
|
|
||||||
w.tun.FilteredDevice().SetFilter(filter)
|
w.tun.FilteredDevice().SetFilter(filter)
|
||||||
return nil
|
return nil
|
||||||
@@ -212,9 +211,13 @@ func (w *WGIface) GetWGDevice() *wgdevice.Device {
|
|||||||
return w.tun.Device()
|
return w.tun.Device()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetStats returns the last handshake time, rx and tx bytes for the given peer
|
// GetStats returns the last handshake time, rx and tx bytes
|
||||||
func (w *WGIface) GetStats(peerKey string) (configurer.WGStats, error) {
|
func (w *WGIface) GetStats() (map[string]configurer.WGStats, error) {
|
||||||
return w.configurer.GetStats(peerKey)
|
return w.configurer.GetStats()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (w *WGIface) FullStats() (*configurer.Stats, error) {
|
||||||
|
return w.configurer.FullStats()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *WGIface) waitUntilRemoved() error {
|
func (w *WGIface) waitUntilRemoved() error {
|
||||||
|
|||||||
@@ -2,7 +2,11 @@
|
|||||||
|
|
||||||
package iface
|
package iface
|
||||||
|
|
||||||
import "fmt"
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
@@ -21,6 +25,6 @@ func (w *WGIface) Create() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateOnAndroid this function make sense on mobile only
|
// CreateOnAndroid this function make sense on mobile only
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
|
||||||
return fmt.Errorf("this function has not implemented on non mobile")
|
return fmt.Errorf("this function has not implemented on non mobile")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,13 @@ package iface
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
// CreateOnAndroid creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
// Will reuse an existing one.
|
// Will reuse an existing one.
|
||||||
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains []string) error {
|
func (w *WGIface) CreateOnAndroid(routes []string, dns string, searchDomains domain.List) error {
|
||||||
w.mu.Lock()
|
w.mu.Lock()
|
||||||
defer w.mu.Unlock()
|
defer w.mu.Unlock()
|
||||||
|
|
||||||
|
|||||||
@@ -7,6 +7,8 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/cenkalti/backoff/v4"
|
"github.com/cenkalti/backoff/v4"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
// Create creates a new Wireguard interface, sets a given IP and brings it up.
|
||||||
@@ -36,6 +38,6 @@ func (w *WGIface) Create() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// CreateOnAndroid this function make sense on mobile only
|
// CreateOnAndroid this function make sense on mobile only
|
||||||
func (w *WGIface) CreateOnAndroid([]string, string, []string) error {
|
func (w *WGIface) CreateOnAndroid([]string, string, domain.List) error {
|
||||||
return fmt.Errorf("this function has not implemented on this platform")
|
return fmt.Errorf("this function has not implemented on this platform")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,7 +5,6 @@
|
|||||||
package mocks
|
package mocks
|
||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
@@ -90,15 +89,3 @@ func (mr *MockPacketFilterMockRecorder) RemovePacketHook(arg0 interface{}) *gomo
|
|||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "RemovePacketHook", reflect.TypeOf((*MockPacketFilter)(nil).RemovePacketHook), arg0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetNetwork mocks base method.
|
|
||||||
func (m *MockPacketFilter) SetNetwork(arg0 *net.IPNet) {
|
|
||||||
m.ctrl.T.Helper()
|
|
||||||
m.ctrl.Call(m, "SetNetwork", arg0)
|
|
||||||
}
|
|
||||||
|
|
||||||
// SetNetwork indicates an expected call of SetNetwork.
|
|
||||||
func (mr *MockPacketFilterMockRecorder) SetNetwork(arg0 interface{}) *gomock.Call {
|
|
||||||
mr.mock.ctrl.T.Helper()
|
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetNetwork", reflect.TypeOf((*MockPacketFilter)(nil).SetNetwork), arg0)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,8 +1,6 @@
|
|||||||
package netstack
|
package netstack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
"os"
|
||||||
"strconv"
|
"strconv"
|
||||||
@@ -15,8 +13,8 @@ import (
|
|||||||
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
const EnvSkipProxy = "NB_NETSTACK_SKIP_PROXY"
|
||||||
|
|
||||||
type NetStackTun struct { //nolint:revive
|
type NetStackTun struct { //nolint:revive
|
||||||
address net.IP
|
address netip.Addr
|
||||||
dnsAddress net.IP
|
dnsAddress netip.Addr
|
||||||
mtu int
|
mtu int
|
||||||
listenAddress string
|
listenAddress string
|
||||||
|
|
||||||
@@ -24,7 +22,7 @@ type NetStackTun struct { //nolint:revive
|
|||||||
tundev tun.Device
|
tundev tun.Device
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu int) *NetStackTun {
|
func NewNetStackTun(listenAddress string, address netip.Addr, dnsAddress netip.Addr, mtu int) *NetStackTun {
|
||||||
return &NetStackTun{
|
return &NetStackTun{
|
||||||
address: address,
|
address: address,
|
||||||
dnsAddress: dnsAddress,
|
dnsAddress: dnsAddress,
|
||||||
@@ -34,19 +32,9 @@ func NewNetStackTun(listenAddress string, address net.IP, dnsAddress net.IP, mtu
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
func (t *NetStackTun) Create() (tun.Device, *netstack.Net, error) {
|
||||||
addr, ok := netip.AddrFromSlice(t.address)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert address to netip.Addr: %v", t.address)
|
|
||||||
}
|
|
||||||
|
|
||||||
dnsAddr, ok := netip.AddrFromSlice(t.dnsAddress)
|
|
||||||
if !ok {
|
|
||||||
return nil, nil, fmt.Errorf("convert dns address to netip.Addr: %v", t.dnsAddress)
|
|
||||||
}
|
|
||||||
|
|
||||||
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
nsTunDev, tunNet, err := netstack.CreateNetTUN(
|
||||||
[]netip.Addr{addr.Unmap()},
|
[]netip.Addr{t.address},
|
||||||
[]netip.Addr{dnsAddr.Unmap()},
|
[]netip.Addr{t.dnsAddress},
|
||||||
t.mtu)
|
t.mtu)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, nil, err
|
return nil, nil, err
|
||||||
|
|||||||
@@ -2,28 +2,27 @@ package wgaddr
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Address WireGuard parsed address
|
// Address WireGuard parsed address
|
||||||
type Address struct {
|
type Address struct {
|
||||||
IP net.IP
|
IP netip.Addr
|
||||||
Network *net.IPNet
|
Network netip.Prefix
|
||||||
}
|
}
|
||||||
|
|
||||||
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
// ParseWGAddress parse a string ("1.2.3.4/24") address to WG Address
|
||||||
func ParseWGAddress(address string) (Address, error) {
|
func ParseWGAddress(address string) (Address, error) {
|
||||||
ip, network, err := net.ParseCIDR(address)
|
prefix, err := netip.ParsePrefix(address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return Address{}, err
|
return Address{}, err
|
||||||
}
|
}
|
||||||
return Address{
|
return Address{
|
||||||
IP: ip,
|
IP: prefix.Addr().Unmap(),
|
||||||
Network: network,
|
Network: prefix.Masked(),
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (addr Address) String() string {
|
func (addr Address) String() string {
|
||||||
maskSize, _ := addr.Network.Mask.Size()
|
return fmt.Sprintf("%s/%d", addr.IP.String(), addr.Network.Bits())
|
||||||
return fmt.Sprintf("%s/%d", addr.IP.String(), maskSize)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -24,6 +24,8 @@
|
|||||||
|
|
||||||
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
!define AUTOSTART_REG_KEY "Software\Microsoft\Windows\CurrentVersion\Run"
|
||||||
|
|
||||||
|
!define NETBIRD_DATA_DIR "$COMMONPROGRAMDATA\Netbird"
|
||||||
|
|
||||||
Unicode True
|
Unicode True
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
@@ -49,6 +51,10 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
|
!include "MUI2.nsh"
|
||||||
|
!include LogicLib.nsh
|
||||||
|
!include "nsDialogs.nsh"
|
||||||
|
|
||||||
!define MUI_ICON "${ICON}"
|
!define MUI_ICON "${ICON}"
|
||||||
!define MUI_UNICON "${ICON}"
|
!define MUI_UNICON "${ICON}"
|
||||||
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
!define MUI_WELCOMEFINISHPAGE_BITMAP "${BANNER}"
|
||||||
@@ -58,9 +64,6 @@ ShowInstDetails Show
|
|||||||
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
!define MUI_FINISHPAGE_RUN_FUNCTION "LaunchLink"
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
!include "MUI2.nsh"
|
|
||||||
!include LogicLib.nsh
|
|
||||||
|
|
||||||
!define MUI_ABORTWARNING
|
!define MUI_ABORTWARNING
|
||||||
!define MUI_UNABORTWARNING
|
!define MUI_UNABORTWARNING
|
||||||
|
|
||||||
@@ -70,13 +73,16 @@ ShowInstDetails Show
|
|||||||
|
|
||||||
!insertmacro MUI_PAGE_DIRECTORY
|
!insertmacro MUI_PAGE_DIRECTORY
|
||||||
|
|
||||||
; Custom page for autostart checkbox
|
|
||||||
Page custom AutostartPage AutostartPageLeave
|
Page custom AutostartPage AutostartPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_INSTFILES
|
!insertmacro MUI_PAGE_INSTFILES
|
||||||
|
|
||||||
!insertmacro MUI_PAGE_FINISH
|
!insertmacro MUI_PAGE_FINISH
|
||||||
|
|
||||||
|
!insertmacro MUI_UNPAGE_WELCOME
|
||||||
|
|
||||||
|
UninstPage custom un.DeleteDataPage un.DeleteDataPageLeave
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_CONFIRM
|
!insertmacro MUI_UNPAGE_CONFIRM
|
||||||
|
|
||||||
!insertmacro MUI_UNPAGE_INSTFILES
|
!insertmacro MUI_UNPAGE_INSTFILES
|
||||||
@@ -89,6 +95,10 @@ Page custom AutostartPage AutostartPageLeave
|
|||||||
Var AutostartCheckbox
|
Var AutostartCheckbox
|
||||||
Var AutostartEnabled
|
Var AutostartEnabled
|
||||||
|
|
||||||
|
; Variables for uninstall data deletion option
|
||||||
|
Var DeleteDataCheckbox
|
||||||
|
Var DeleteDataEnabled
|
||||||
|
|
||||||
######################################################################
|
######################################################################
|
||||||
|
|
||||||
; Function to create the autostart options page
|
; Function to create the autostart options page
|
||||||
@@ -104,8 +114,8 @@ Function AutostartPage
|
|||||||
|
|
||||||
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Start ${APP_NAME} UI automatically when Windows starts"
|
||||||
Pop $AutostartCheckbox
|
Pop $AutostartCheckbox
|
||||||
${NSD_Check} $AutostartCheckbox ; Default to checked
|
${NSD_Check} $AutostartCheckbox
|
||||||
StrCpy $AutostartEnabled "1" ; Default to enabled
|
StrCpy $AutostartEnabled "1"
|
||||||
|
|
||||||
nsDialogs::Show
|
nsDialogs::Show
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
@@ -115,6 +125,30 @@ Function AutostartPageLeave
|
|||||||
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
${NSD_GetState} $AutostartCheckbox $AutostartEnabled
|
||||||
FunctionEnd
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to create the uninstall data deletion page
|
||||||
|
Function un.DeleteDataPage
|
||||||
|
!insertmacro MUI_HEADER_TEXT "Uninstall Options" "Choose whether to delete ${APP_NAME} data."
|
||||||
|
|
||||||
|
nsDialogs::Create 1018
|
||||||
|
Pop $0
|
||||||
|
|
||||||
|
${If} $0 == error
|
||||||
|
Abort
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
|
${NSD_CreateCheckbox} 0 20u 100% 10u "Delete all ${APP_NAME} configuration and state data (${NETBIRD_DATA_DIR})"
|
||||||
|
Pop $DeleteDataCheckbox
|
||||||
|
${NSD_Uncheck} $DeleteDataCheckbox
|
||||||
|
StrCpy $DeleteDataEnabled "0"
|
||||||
|
|
||||||
|
nsDialogs::Show
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
|
; Function to handle leaving the data deletion page
|
||||||
|
Function un.DeleteDataPageLeave
|
||||||
|
${NSD_GetState} $DeleteDataCheckbox $DeleteDataEnabled
|
||||||
|
FunctionEnd
|
||||||
|
|
||||||
Function GetAppFromCommand
|
Function GetAppFromCommand
|
||||||
Exch $1
|
Exch $1
|
||||||
Push $2
|
Push $2
|
||||||
@@ -225,31 +259,58 @@ SectionEnd
|
|||||||
Section Uninstall
|
Section Uninstall
|
||||||
${INSTALL_TYPE}
|
${INSTALL_TYPE}
|
||||||
|
|
||||||
|
DetailPrint "Stopping Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service stop'
|
||||||
|
DetailPrint "Uninstalling Netbird service..."
|
||||||
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
ExecWait '"$INSTDIR\${MAIN_APP_EXE}" service uninstall'
|
||||||
|
|
||||||
# kill ui client
|
DetailPrint "Terminating Netbird UI process..."
|
||||||
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
ExecWait `taskkill /im ${UI_APP_EXE}.exe /f`
|
||||||
|
|
||||||
; Remove autostart registry entry
|
; Remove autostart registry entry
|
||||||
|
DetailPrint "Removing autostart registry entry if exists..."
|
||||||
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
DeleteRegValue HKCU "${AUTOSTART_REG_KEY}" "${APP_NAME}"
|
||||||
|
|
||||||
|
; Handle data deletion based on checkbox
|
||||||
|
DetailPrint "Checking if user requested data deletion..."
|
||||||
|
${If} $DeleteDataEnabled == "1"
|
||||||
|
DetailPrint "User opted to delete Netbird data. Removing ${NETBIRD_DATA_DIR}..."
|
||||||
|
ClearErrors
|
||||||
|
RMDir /r "${NETBIRD_DATA_DIR}"
|
||||||
|
IfErrors 0 +2 ; If no errors, jump over the message
|
||||||
|
DetailPrint "Error deleting Netbird data directory. It might be in use or already removed."
|
||||||
|
DetailPrint "Netbird data directory removal complete."
|
||||||
|
${Else}
|
||||||
|
DetailPrint "User did not opt to delete Netbird data."
|
||||||
|
${EndIf}
|
||||||
|
|
||||||
# wait the service uninstall take unblock the executable
|
# wait the service uninstall take unblock the executable
|
||||||
|
DetailPrint "Waiting for service handle to be released..."
|
||||||
Sleep 3000
|
Sleep 3000
|
||||||
|
|
||||||
|
DetailPrint "Deleting application files..."
|
||||||
Delete "$INSTDIR\${UI_APP_EXE}"
|
Delete "$INSTDIR\${UI_APP_EXE}"
|
||||||
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
Delete "$INSTDIR\${MAIN_APP_EXE}"
|
||||||
Delete "$INSTDIR\wintun.dll"
|
Delete "$INSTDIR\wintun.dll"
|
||||||
Delete "$INSTDIR\opengl32.dll"
|
Delete "$INSTDIR\opengl32.dll"
|
||||||
|
DetailPrint "Removing application directory..."
|
||||||
RmDir /r "$INSTDIR"
|
RmDir /r "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Removing shortcuts..."
|
||||||
SetShellVarContext all
|
SetShellVarContext all
|
||||||
Delete "$DESKTOP\${APP_NAME}.lnk"
|
Delete "$DESKTOP\${APP_NAME}.lnk"
|
||||||
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
Delete "$SMPROGRAMS\${APP_NAME}.lnk"
|
||||||
|
|
||||||
|
DetailPrint "Removing registry keys..."
|
||||||
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
DeleteRegKey ${REG_ROOT} "${REG_APP_PATH}"
|
||||||
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
DeleteRegKey ${REG_ROOT} "${UNINSTALL_PATH}"
|
||||||
|
DeleteRegKey ${REG_ROOT} "${UI_REG_APP_PATH}"
|
||||||
|
|
||||||
|
DetailPrint "Removing application directory from PATH..."
|
||||||
EnVar::SetHKLM
|
EnVar::SetHKLM
|
||||||
EnVar::DeleteValue "path" "$INSTDIR"
|
EnVar::DeleteValue "path" "$INSTDIR"
|
||||||
|
|
||||||
|
DetailPrint "Uninstallation finished."
|
||||||
SectionEnd
|
SectionEnd
|
||||||
|
|
||||||
|
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
|
|||||||
|
|
||||||
func GenerateRouteRuleKey(
|
func GenerateRouteRuleKey(
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination manager.Network,
|
||||||
proto manager.Protocol,
|
proto manager.Protocol,
|
||||||
sPort *manager.Port,
|
sPort *manager.Port,
|
||||||
dPort *manager.Port,
|
dPort *manager.Port,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,7 +26,7 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
|||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||||
}
|
}
|
||||||
|
|
||||||
type protoMatch struct {
|
type protoMatch struct {
|
||||||
@@ -53,10 +54,15 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
|||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||||
//
|
//
|
||||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
||||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
|
||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
|
|
||||||
|
if d.firewall == nil {
|
||||||
|
log.Debug("firewall manager is not supported, skipping firewall rules")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
defer func() {
|
defer func() {
|
||||||
total := 0
|
total := 0
|
||||||
@@ -68,21 +74,9 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
time.Since(start), total)
|
time.Since(start), total)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
if d.firewall == nil {
|
|
||||||
log.Debug("firewall manager is not supported, skipping firewall rules")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
d.applyPeerACLs(networkMap)
|
d.applyPeerACLs(networkMap)
|
||||||
|
|
||||||
// If we got empty rules list but management did not set the networkMap.FirewallRulesIsEmpty flag,
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
// then the mgmt server is older than the client, and we need to allow all traffic for routes
|
|
||||||
isLegacy := len(networkMap.RoutesFirewallRules) == 0 && !networkMap.RoutesFirewallRulesIsEmpty
|
|
||||||
if err := d.firewall.SetLegacyManagement(isLegacy); err != nil {
|
|
||||||
log.Errorf("failed to set legacy management flag: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
|
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -176,16 +170,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
d.peerRulesPairs = newRulePairs
|
d.peerRulesPairs = newRulePairs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
id, err := d.applyRouteACL(rule)
|
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||||
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||||
} else {
|
} else {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||||
}
|
}
|
||||||
@@ -208,7 +202,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
||||||
if len(rule.SourceRanges) == 0 {
|
if len(rule.SourceRanges) == 0 {
|
||||||
return "", ErrSourceRangesEmpty
|
return "", ErrSourceRangesEmpty
|
||||||
}
|
}
|
||||||
@@ -222,15 +216,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
sources = append(sources, source)
|
sources = append(sources, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
var destination netip.Prefix
|
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||||
if rule.IsDynamic {
|
|
||||||
destination = getDefault(sources[0])
|
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
destination, err = netip.ParsePrefix(rule.Destination)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("parse destination: %w", err)
|
return "", fmt.Errorf("determine destination: %w", err)
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||||
@@ -296,8 +284,10 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
if d.firewall.IsStateful() {
|
||||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
return "", nil, nil
|
||||||
|
}
|
||||||
|
// return traffic for outbound connections if firewall is stateless
|
||||||
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
@@ -580,6 +570,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
|
||||||
|
var destination firewall.Network
|
||||||
|
|
||||||
|
if rule.IsDynamic {
|
||||||
|
if dynamicResolver {
|
||||||
|
if len(rule.Domains) > 0 {
|
||||||
|
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
|
||||||
|
} else {
|
||||||
|
// isDynamic is set but no domains = outdated management server
|
||||||
|
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
|
||||||
|
destination.Prefix = getDefault(sources[0])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// client resolves DNS, we (router) don't know the destination
|
||||||
|
destination.Prefix = getDefault(sources[0])
|
||||||
|
}
|
||||||
|
return destination, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, err := netip.ParsePrefix(rule.Destination)
|
||||||
|
if err != nil {
|
||||||
|
return destination, fmt.Errorf("parse destination: %w", err)
|
||||||
|
}
|
||||||
|
destination.Prefix = prefix
|
||||||
|
return destination, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getDefault(prefix netip.Prefix) netip.Prefix {
|
func getDefault(prefix netip.Prefix) netip.Prefix {
|
||||||
if prefix.Addr().Is6() {
|
if prefix.Addr().Is6() {
|
||||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
|||||||
@@ -1,13 +1,14 @@
|
|||||||
package acl
|
package acl
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
"github.com/golang/mock/gomock"
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall"
|
"github.com/netbirdio/netbird/client/firewall"
|
||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
"github.com/netbirdio/netbird/client/internal/netflow"
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
@@ -42,35 +43,31 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Errorf("create firewall: %v", err)
|
defer func() {
|
||||||
return
|
err = fw.Close(nil)
|
||||||
}
|
require.NoError(t, err)
|
||||||
defer func(fw manager.Manager) {
|
}()
|
||||||
_ = fw.Close(nil)
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
t.Run("apply firewall rules", func(t *testing.T) {
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if fw.IsStateful() {
|
||||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||||
return
|
} else {
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -92,14 +89,15 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
expectedRules := 2
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if fw.IsStateful() {
|
||||||
t.Errorf("firewall rules not applied")
|
expectedRules = 1 // only the inbound rule
|
||||||
return
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
|
|
||||||
// check that old rule was removed
|
// check that old rule was removed
|
||||||
previousCount := 0
|
previousCount := 0
|
||||||
for id := range acl.peerRulesPairs {
|
for id := range acl.peerRulesPairs {
|
||||||
@@ -107,26 +105,86 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
previousCount++
|
previousCount++
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if previousCount != 1 {
|
|
||||||
t.Errorf("old rule was not removed")
|
expectedPreviousCount := 0
|
||||||
|
if !fw.IsStateful() {
|
||||||
|
expectedPreviousCount = 1
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedPreviousCount, previousCount)
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("handle default rules", func(t *testing.T) {
|
t.Run("handle default rules", func(t *testing.T) {
|
||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
networkMap.FirewallRulesIsEmpty = true
|
||||||
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
|
acl.ApplyFiltering(networkMap, false)
|
||||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
assert.Equal(t, 0, len(acl.peerRulesPairs))
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
if len(acl.peerRulesPairs) != 1 {
|
|
||||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
expectedRules := 1
|
||||||
return
|
if fw.IsStateful() {
|
||||||
|
expectedRules = 1 // only inbound allow-all rule
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDefaultManagerStateless(t *testing.T) {
|
||||||
|
// stateless currently only in userspace, so we have to disable kernel
|
||||||
|
t.Setenv("NB_WG_KERNEL_DISABLED", "true")
|
||||||
|
t.Setenv("NB_DISABLE_CONNTRACK", "true")
|
||||||
|
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.1",
|
||||||
|
Direction: mgmProto.RuleDirection_OUT,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_TCP,
|
||||||
|
Port: "80",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
PeerIP: "10.93.0.2",
|
||||||
|
Direction: mgmProto.RuleDirection_IN,
|
||||||
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
|
Protocol: mgmProto.RuleProtocol_UDP,
|
||||||
|
Port: "53",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
defer ctrl.Finish()
|
||||||
|
|
||||||
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
|
|
||||||
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
|
IP: network.Addr(),
|
||||||
|
Network: network,
|
||||||
|
}).AnyTimes()
|
||||||
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
|
require.NoError(t, err)
|
||||||
|
defer func() {
|
||||||
|
err = fw.Close(nil)
|
||||||
|
require.NoError(t, err)
|
||||||
|
}()
|
||||||
|
|
||||||
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
|
t.Run("stateless firewall creates outbound rules", func(t *testing.T) {
|
||||||
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
|
// In stateless mode, we should have both inbound and outbound rules
|
||||||
|
assert.False(t, fw.IsStateful())
|
||||||
|
assert.Equal(t, 2, len(acl.peerRulesPairs))
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -192,42 +250,19 @@ func TestDefaultManagerSquashRules(t *testing.T) {
|
|||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
rules, _ := manager.squashAcceptRules(networkMap)
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
if len(rules) != 2 {
|
assert.Equal(t, 2, len(rules))
|
||||||
t.Errorf("rules should contain 2, got: %v", rules)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r := rules[0]
|
r := rules[0]
|
||||||
switch {
|
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||||
case r.PeerIP != "0.0.0.0":
|
assert.Equal(t, mgmProto.RuleDirection_IN, r.Direction)
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||||
return
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||||
case r.Direction != mgmProto.RuleDirection_IN:
|
|
||||||
t.Errorf("direction should be IN, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
r = rules[1]
|
r = rules[1]
|
||||||
switch {
|
assert.Equal(t, "0.0.0.0", r.PeerIP)
|
||||||
case r.PeerIP != "0.0.0.0":
|
assert.Equal(t, mgmProto.RuleDirection_OUT, r.Direction)
|
||||||
t.Errorf("IP should be 0.0.0.0, got: %v", r.PeerIP)
|
assert.Equal(t, mgmProto.RuleProtocol_ALL, r.Protocol)
|
||||||
return
|
assert.Equal(t, mgmProto.RuleAction_ACCEPT, r.Action)
|
||||||
case r.Direction != mgmProto.RuleDirection_OUT:
|
|
||||||
t.Errorf("direction should be OUT, got: %v", r.Direction)
|
|
||||||
return
|
|
||||||
case r.Protocol != mgmProto.RuleProtocol_ALL:
|
|
||||||
t.Errorf("protocol should be ALL, got: %v", r.Protocol)
|
|
||||||
return
|
|
||||||
case r.Action != mgmProto.RuleAction_ACCEPT:
|
|
||||||
t.Errorf("action should be ACCEPT, got: %v", r.Action)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
||||||
@@ -291,9 +326,8 @@ func TestDefaultManagerSquashRulesNoAffect(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager := &DefaultManager{}
|
manager := &DefaultManager{}
|
||||||
if rules, _ := manager.squashAcceptRules(networkMap); len(rules) != len(networkMap.FirewallRules) {
|
rules, _ := manager.squashAcceptRules(networkMap)
|
||||||
t.Errorf("we should get the same amount of rules as output, got %v", len(rules))
|
assert.Equal(t, len(networkMap.FirewallRules), len(rules))
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
||||||
@@ -336,33 +370,29 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
ifaceMock := mocks.NewMockIFaceMapper(ctrl)
|
||||||
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
ifaceMock.EXPECT().IsUserspaceBind().Return(true).AnyTimes()
|
||||||
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
ifaceMock.EXPECT().SetFilter(gomock.Any())
|
||||||
ip, network, err := net.ParseCIDR("172.0.0.1/32")
|
network := netip.MustParsePrefix("172.0.0.1/32")
|
||||||
if err != nil {
|
|
||||||
t.Fatalf("failed to parse IP address: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
ifaceMock.EXPECT().Name().Return("lo").AnyTimes()
|
||||||
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
ifaceMock.EXPECT().Address().Return(wgaddr.Address{
|
||||||
IP: ip,
|
IP: network.Addr(),
|
||||||
Network: network,
|
Network: network,
|
||||||
}).AnyTimes()
|
}).AnyTimes()
|
||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
require.NoError(t, err)
|
||||||
t.Errorf("create firewall: %v", err)
|
defer func() {
|
||||||
return
|
err = fw.Close(nil)
|
||||||
}
|
require.NoError(t, err)
|
||||||
defer func(fw manager.Manager) {
|
}()
|
||||||
_ = fw.Close(nil)
|
|
||||||
}(fw)
|
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 3 {
|
expectedRules := 3
|
||||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
if fw.IsStateful() {
|
||||||
return
|
expectedRules = 3 // 2 inbound rules + SSH rule
|
||||||
}
|
}
|
||||||
|
assert.Equal(t, expectedRules, len(acl.peerRulesPairs))
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -64,13 +64,8 @@ func (t TokenInfo) GetTokenToUse() string {
|
|||||||
// and if that also fails, the authentication process is deemed unsuccessful
|
// and if that also fails, the authentication process is deemed unsuccessful
|
||||||
//
|
//
|
||||||
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
// On Linux distros without desktop environment support, it only tries to initialize the Device Code Flow
|
||||||
func NewOAuthFlow(ctx context.Context, config *internal.Config, isLinuxDesktopClient bool) (OAuthFlow, error) {
|
func NewOAuthFlow(ctx context.Context, config *internal.Config, isUnixDesktopClient bool) (OAuthFlow, error) {
|
||||||
if runtime.GOOS == "linux" && !isLinuxDesktopClient {
|
if (runtime.GOOS == "linux" || runtime.GOOS == "freebsd") && !isUnixDesktopClient {
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
|
||||||
}
|
|
||||||
|
|
||||||
// On FreeBSD we currently do not support desktop environments and offer only Device Code Flow (#2384)
|
|
||||||
if runtime.GOOS == "freebsd" {
|
|
||||||
return authenticateWithDeviceCodeFlow(ctx, config)
|
return authenticateWithDeviceCodeFlow(ctx, config)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -94,12 +94,22 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
p.codeVerifier = codeVerifier
|
p.codeVerifier = codeVerifier
|
||||||
|
|
||||||
codeChallenge := createCodeChallenge(codeVerifier)
|
codeChallenge := createCodeChallenge(codeVerifier)
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(
|
|
||||||
state,
|
params := []oauth2.AuthCodeOption{
|
||||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
)
|
}
|
||||||
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
|
if p.providerConfig.LoginFlag.IsPromptLogin() {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
|
}
|
||||||
|
if p.providerConfig.LoginFlag.IsMaxAge0Login() {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("max_age", "0"))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|
||||||
return AuthFlowInfo{
|
return AuthFlowInfo{
|
||||||
VerificationURIComplete: authURL,
|
VerificationURIComplete: authURL,
|
||||||
|
|||||||
71
client/internal/auth/pkce_flow_test.go
Normal file
71
client/internal/auth/pkce_flow_test.go
Normal file
@@ -0,0 +1,71 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
mgm "github.com/netbirdio/netbird/management/client/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPromptLogin(t *testing.T) {
|
||||||
|
const (
|
||||||
|
promptLogin = "prompt=login"
|
||||||
|
maxAge0 = "max_age=0"
|
||||||
|
)
|
||||||
|
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
loginFlag mgm.LoginFlag
|
||||||
|
disablePromptLogin bool
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
expect: promptLogin,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Max age 0 login",
|
||||||
|
loginFlag: mgm.LoginFlagMaxAge0,
|
||||||
|
expect: maxAge0,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Disable prompt login",
|
||||||
|
loginFlag: mgm.LoginFlagPrompt,
|
||||||
|
disablePromptLogin: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
config := internal.PKCEAuthProviderConfig{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Audience: "test-audience",
|
||||||
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
Scope: "openid email profile",
|
||||||
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
|
UseIDToken: true,
|
||||||
|
LoginFlag: tc.loginFlag,
|
||||||
|
}
|
||||||
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
|
||||||
|
}
|
||||||
|
authInfo, err := pkce.RequestAuthInfo(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !tc.disablePromptLogin {
|
||||||
|
require.Contains(t, authInfo.VerificationURIComplete, tc.expect)
|
||||||
|
} else {
|
||||||
|
require.Contains(t, authInfo.VerificationURIComplete, promptLogin)
|
||||||
|
require.NotContains(t, authInfo.VerificationURIComplete, maxAge0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -68,12 +68,14 @@ type ConfigInput struct {
|
|||||||
DisableServerRoutes *bool
|
DisableServerRoutes *bool
|
||||||
DisableDNS *bool
|
DisableDNS *bool
|
||||||
DisableFirewall *bool
|
DisableFirewall *bool
|
||||||
|
|
||||||
BlockLANAccess *bool
|
BlockLANAccess *bool
|
||||||
|
BlockInbound *bool
|
||||||
|
|
||||||
DisableNotifications *bool
|
DisableNotifications *bool
|
||||||
|
|
||||||
DNSLabels domain.List
|
DNSLabels domain.List
|
||||||
|
|
||||||
|
LazyConnectionEnabled *bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// Config Configuration type
|
// Config Configuration type
|
||||||
@@ -96,8 +98,8 @@ type Config struct {
|
|||||||
DisableServerRoutes bool
|
DisableServerRoutes bool
|
||||||
DisableDNS bool
|
DisableDNS bool
|
||||||
DisableFirewall bool
|
DisableFirewall bool
|
||||||
|
|
||||||
BlockLANAccess bool
|
BlockLANAccess bool
|
||||||
|
BlockInbound bool
|
||||||
|
|
||||||
DisableNotifications *bool
|
DisableNotifications *bool
|
||||||
|
|
||||||
@@ -138,6 +140,8 @@ type Config struct {
|
|||||||
ClientCertKeyPath string
|
ClientCertKeyPath string
|
||||||
|
|
||||||
ClientCertKeyPair *tls.Certificate `json:"-"`
|
ClientCertKeyPair *tls.Certificate `json:"-"`
|
||||||
|
|
||||||
|
LazyConnectionEnabled bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
// ReadConfig read config file and return with Config. If it is not exists create a new with default values
|
||||||
@@ -479,6 +483,16 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.BlockInbound != nil && *input.BlockInbound != config.BlockInbound {
|
||||||
|
if *input.BlockInbound {
|
||||||
|
log.Infof("blocking inbound connections")
|
||||||
|
} else {
|
||||||
|
log.Infof("allowing inbound connections")
|
||||||
|
}
|
||||||
|
config.BlockInbound = *input.BlockInbound
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
if input.DisableNotifications != nil && input.DisableNotifications != config.DisableNotifications {
|
||||||
if *input.DisableNotifications {
|
if *input.DisableNotifications {
|
||||||
log.Infof("disabling notifications")
|
log.Infof("disabling notifications")
|
||||||
@@ -524,6 +538,12 @@ func (config *Config) apply(input ConfigInput) (updated bool, err error) {
|
|||||||
updated = true
|
updated = true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if input.LazyConnectionEnabled != nil && *input.LazyConnectionEnabled != config.LazyConnectionEnabled {
|
||||||
|
log.Infof("switching lazy connection to %t", *input.LazyConnectionEnabled)
|
||||||
|
config.LazyConnectionEnabled = *input.LazyConnectionEnabled
|
||||||
|
updated = true
|
||||||
|
}
|
||||||
|
|
||||||
return updated, nil
|
return updated, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
303
client/internal/conn_mgr.go
Normal file
303
client/internal/conn_mgr.go
Normal file
@@ -0,0 +1,303 @@
|
|||||||
|
package internal
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/lazyconn/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer/dispatcher"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peerstore"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ConnMgr coordinates both lazy connections (established on-demand) and permanent peer connections.
|
||||||
|
//
|
||||||
|
// The connection manager is responsible for:
|
||||||
|
// - Managing lazy connections via the lazyConnManager
|
||||||
|
// - Maintaining a list of excluded peers that should always have permanent connections
|
||||||
|
// - Handling connection establishment based on peer signaling
|
||||||
|
//
|
||||||
|
// The implementation is not thread-safe; it is protected by engine.syncMsgMux.
|
||||||
|
type ConnMgr struct {
|
||||||
|
peerStore *peerstore.Store
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
iface lazyconn.WGIface
|
||||||
|
dispatcher *dispatcher.ConnectionDispatcher
|
||||||
|
enabledLocally bool
|
||||||
|
|
||||||
|
lazyConnMgr *manager.Manager
|
||||||
|
|
||||||
|
wg sync.WaitGroup
|
||||||
|
ctx context.Context
|
||||||
|
ctxCancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewConnMgr(engineConfig *EngineConfig, statusRecorder *peer.Status, peerStore *peerstore.Store, iface lazyconn.WGIface, dispatcher *dispatcher.ConnectionDispatcher) *ConnMgr {
|
||||||
|
e := &ConnMgr{
|
||||||
|
peerStore: peerStore,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
|
iface: iface,
|
||||||
|
dispatcher: dispatcher,
|
||||||
|
}
|
||||||
|
if engineConfig.LazyConnectionEnabled || lazyconn.IsLazyConnEnabledByEnv() {
|
||||||
|
e.enabledLocally = true
|
||||||
|
}
|
||||||
|
return e
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start initializes the connection manager and starts the lazy connection manager if enabled by env var or cmd line option.
|
||||||
|
func (e *ConnMgr) Start(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
log.Errorf("lazy connection manager is already started")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.enabledLocally {
|
||||||
|
log.Infof("lazy connection manager is disabled")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdatedRemoteFeatureFlag is called when the remote feature flag is updated.
|
||||||
|
// If enabled, it initializes the lazy connection manager and start it. Do not need to call Start() again.
|
||||||
|
// If disabled, then it closes the lazy connection manager and open the connections to all peers.
|
||||||
|
func (e *ConnMgr) UpdatedRemoteFeatureFlag(ctx context.Context, enabled bool) error {
|
||||||
|
// do not disable lazy connection manager if it was enabled by env var
|
||||||
|
if e.enabledLocally {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if enabled {
|
||||||
|
// if the lazy connection manager is already started, do not start it again
|
||||||
|
if e.lazyConnMgr != nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("lazy connection manager is enabled by management feature flag")
|
||||||
|
e.initLazyManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(true)
|
||||||
|
return e.addPeersToLazyConnManager(ctx)
|
||||||
|
} else {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
log.Infof("lazy connection manager is disabled by management feature flag")
|
||||||
|
e.closeManager(ctx)
|
||||||
|
e.statusRecorder.UpdateLazyConnection(false)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetExcludeList sets the list of peer IDs that should always have permanent connections.
|
||||||
|
func (e *ConnMgr) SetExcludeList(peerIDs map[string]bool) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
excludedPeers := make([]lazyconn.PeerConfig, 0, len(peerIDs))
|
||||||
|
|
||||||
|
for peerID := range peerIDs {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
excludedPeers = append(excludedPeers, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
added := e.lazyConnMgr.ExcludePeer(e.ctx, excludedPeers)
|
||||||
|
for _, peerID := range added {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
// if the peer not exist in the store, it means that the engine will call the AddPeerConn in next step
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
peerConn.Log.Infof("peer has been added to lazy connection exclude list, opening permanent connection")
|
||||||
|
if err := peerConn.Open(e.ctx); err != nil {
|
||||||
|
peerConn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) AddPeerConn(ctx context.Context, peerKey string, conn *peer.Conn) (exists bool) {
|
||||||
|
if success := e.peerStore.AddPeerConn(peerKey, conn); !success {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if !lazyconn.IsSupported(conn.AgentVersionString()) {
|
||||||
|
conn.Log.Warnf("peer does not support lazy connection (%s), open permanent connection", conn.AgentVersionString())
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerKey,
|
||||||
|
AllowedIPs: conn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: conn.ConnID(),
|
||||||
|
Log: conn.Log,
|
||||||
|
}
|
||||||
|
excluded, err := e.lazyConnMgr.AddPeer(lazyPeerCfg)
|
||||||
|
if err != nil {
|
||||||
|
conn.Log.Errorf("failed to add peer to lazyconn manager: %v", err)
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if excluded {
|
||||||
|
conn.Log.Infof("peer is on lazy conn manager exclude list, opening connection")
|
||||||
|
if err := conn.Open(ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Log.Infof("peer added to lazy conn manager")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) RemovePeerConn(peerKey string) {
|
||||||
|
conn, ok := e.peerStore.Remove(peerKey)
|
||||||
|
if !ok {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.lazyConnMgr.RemovePeer(peerKey)
|
||||||
|
conn.Log.Infof("removed peer from lazy conn manager")
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) OnSignalMsg(ctx context.Context, peerKey string) (*peer.Conn, bool) {
|
||||||
|
conn, ok := e.peerStore.PeerConn(peerKey)
|
||||||
|
if !ok {
|
||||||
|
return nil, false
|
||||||
|
}
|
||||||
|
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
if found := e.lazyConnMgr.ActivatePeer(ctx, peerKey); found {
|
||||||
|
conn.Log.Infof("activated peer from inactive state")
|
||||||
|
if err := conn.Open(e.ctx); err != nil {
|
||||||
|
conn.Log.Errorf("failed to open connection: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return conn, true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) Close() {
|
||||||
|
if !e.isStartedWithLazyMgr() {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.ctxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) initLazyManager(parentCtx context.Context) {
|
||||||
|
cfg := manager.Config{
|
||||||
|
InactivityThreshold: inactivityThresholdEnv(),
|
||||||
|
}
|
||||||
|
e.lazyConnMgr = manager.NewManager(cfg, e.peerStore, e.iface, e.dispatcher)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(parentCtx)
|
||||||
|
e.ctx = ctx
|
||||||
|
e.ctxCancel = cancel
|
||||||
|
|
||||||
|
e.wg.Add(1)
|
||||||
|
go func() {
|
||||||
|
defer e.wg.Done()
|
||||||
|
e.lazyConnMgr.Start(ctx)
|
||||||
|
}()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) addPeersToLazyConnManager(ctx context.Context) error {
|
||||||
|
peers := e.peerStore.PeersPubKey()
|
||||||
|
lazyPeerCfgs := make([]lazyconn.PeerConfig, 0, len(peers))
|
||||||
|
for _, peerID := range peers {
|
||||||
|
var peerConn *peer.Conn
|
||||||
|
var exists bool
|
||||||
|
if peerConn, exists = e.peerStore.PeerConn(peerID); !exists {
|
||||||
|
log.Warnf("failed to find peer conn for peerID: %s", peerID)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
lazyPeerCfg := lazyconn.PeerConfig{
|
||||||
|
PublicKey: peerID,
|
||||||
|
AllowedIPs: peerConn.WgConfig().AllowedIps,
|
||||||
|
PeerConnID: peerConn.ConnID(),
|
||||||
|
Log: peerConn.Log,
|
||||||
|
}
|
||||||
|
lazyPeerCfgs = append(lazyPeerCfgs, lazyPeerCfg)
|
||||||
|
}
|
||||||
|
|
||||||
|
return e.lazyConnMgr.AddActivePeers(ctx, lazyPeerCfgs)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) closeManager(ctx context.Context) {
|
||||||
|
if e.lazyConnMgr == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
e.ctxCancel()
|
||||||
|
e.wg.Wait()
|
||||||
|
e.lazyConnMgr = nil
|
||||||
|
|
||||||
|
for _, peerID := range e.peerStore.PeersPubKey() {
|
||||||
|
e.peerStore.PeerConnOpen(ctx, peerID)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *ConnMgr) isStartedWithLazyMgr() bool {
|
||||||
|
return e.lazyConnMgr != nil && e.ctxCancel != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func inactivityThresholdEnv() *time.Duration {
|
||||||
|
envValue := os.Getenv(lazyconn.EnvInactivityThreshold)
|
||||||
|
if envValue == "" {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
parsedMinutes, err := strconv.Atoi(envValue)
|
||||||
|
if err != nil || parsedMinutes <= 0 {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
d := time.Duration(parsedMinutes) * time.Minute
|
||||||
|
return &d
|
||||||
|
}
|
||||||
@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
|
|||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLatestNetworkMap returns the latest network map from the engine.
|
||||||
|
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||||
|
engine := c.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, errors.New("engine is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
networkMap, err := engine.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get latest network map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if networkMap == nil {
|
||||||
|
return nil, errors.New("network map is not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return networkMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current client status
|
// Status returns the current client status
|
||||||
func (c *ConnectClient) Status() StatusType {
|
func (c *ConnectClient) Status() StatusType {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
@@ -417,11 +436,13 @@ func createEngineConfig(key wgtypes.Key, config *Config, peerConfig *mgmProto.Pe
|
|||||||
DNSRouteInterval: config.DNSRouteInterval,
|
DNSRouteInterval: config.DNSRouteInterval,
|
||||||
|
|
||||||
DisableClientRoutes: config.DisableClientRoutes,
|
DisableClientRoutes: config.DisableClientRoutes,
|
||||||
DisableServerRoutes: config.DisableServerRoutes,
|
DisableServerRoutes: config.DisableServerRoutes || config.BlockInbound,
|
||||||
DisableDNS: config.DisableDNS,
|
DisableDNS: config.DisableDNS,
|
||||||
DisableFirewall: config.DisableFirewall,
|
DisableFirewall: config.DisableFirewall,
|
||||||
|
|
||||||
BlockLANAccess: config.BlockLANAccess,
|
BlockLANAccess: config.BlockLANAccess,
|
||||||
|
BlockInbound: config.BlockInbound,
|
||||||
|
|
||||||
|
LazyConnectionEnabled: config.LazyConnectionEnabled,
|
||||||
}
|
}
|
||||||
|
|
||||||
if config.PreSharedKey != "" {
|
if config.PreSharedKey != "" {
|
||||||
@@ -462,7 +483,7 @@ func connectToSignal(ctx context.Context, wtConfig *mgmProto.NetbirdConfig, ourP
|
|||||||
return signalClient, nil
|
return signalClient, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// loginToManagement creates Management Services client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
// loginToManagement creates Management ServiceDependencies client, establishes a connection, logs-in and gets a global Netbird config (signal, turn, stun hosts, etc)
|
||||||
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
func loginToManagement(ctx context.Context, client mgm.Client, pubSSHKey []byte, config *Config) (*mgmProto.LoginResponse, error) {
|
||||||
|
|
||||||
serverPublicKey, err := client.GetServerPublicKey()
|
serverPublicKey, err := client.GetServerPublicKey()
|
||||||
|
|||||||
1112
client/internal/debug/debug.go
Normal file
1112
client/internal/debug/debug.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,49 +1,130 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android
|
||||||
|
|
||||||
package server
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/zip"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
|
"context"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"os"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxLogEntries = 100000
|
||||||
|
maxLogAge = 7 * 24 * time.Hour // Last 7 days
|
||||||
|
)
|
||||||
|
|
||||||
|
// trySystemdLogFallback attempts to get logs from systemd journal as fallback
|
||||||
|
func (g *BundleGenerator) trySystemdLogFallback() error {
|
||||||
|
log.Debug("Attempting to collect systemd journal logs")
|
||||||
|
|
||||||
|
serviceName := getServiceName()
|
||||||
|
journalLogs, err := getSystemdLogs(serviceName)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get systemd logs for %s: %w", serviceName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if strings.Contains(journalLogs, "No recent log entries found") {
|
||||||
|
log.Debug("No recent log entries found in systemd journal")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if g.anonymize {
|
||||||
|
journalLogs = g.anonymizer.AnonymizeString(journalLogs)
|
||||||
|
}
|
||||||
|
|
||||||
|
logReader := strings.NewReader(journalLogs)
|
||||||
|
fileName := fmt.Sprintf("systemd-%s.log", serviceName)
|
||||||
|
if err := g.addFileToZip(logReader, fileName); err != nil {
|
||||||
|
return fmt.Errorf("add systemd logs to bundle: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Added systemd journal logs for %s to debug bundle", serviceName)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// getServiceName gets the service name from environment or defaults to netbird
|
||||||
|
func getServiceName() string {
|
||||||
|
if unitName := os.Getenv("SYSTEMD_UNIT"); unitName != "" {
|
||||||
|
log.Debugf("Detected SYSTEMD_UNIT environment variable: %s", unitName)
|
||||||
|
return unitName
|
||||||
|
}
|
||||||
|
|
||||||
|
return "netbird"
|
||||||
|
}
|
||||||
|
|
||||||
|
// getSystemdLogs retrieves logs from systemd journal for a specific service using journalctl
|
||||||
|
func getSystemdLogs(serviceName string) (string, error) {
|
||||||
|
args := []string{
|
||||||
|
"-u", fmt.Sprintf("%s.service", serviceName),
|
||||||
|
"--since", fmt.Sprintf("-%s", maxLogAge.String()),
|
||||||
|
"--lines", fmt.Sprintf("%d", maxLogEntries),
|
||||||
|
"--no-pager",
|
||||||
|
"--output", "short-iso",
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
cmd := exec.CommandContext(ctx, "journalctl", args...)
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if errors.Is(ctx.Err(), context.DeadlineExceeded) {
|
||||||
|
return "", fmt.Errorf("journalctl command timed out after 30 seconds")
|
||||||
|
}
|
||||||
|
if strings.Contains(err.Error(), "executable file not found") {
|
||||||
|
return "", fmt.Errorf("journalctl command not found: %w", err)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("execute journalctl: %w (stderr: %s)", err, stderr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
logs := stdout.String()
|
||||||
|
if strings.TrimSpace(logs) == "" {
|
||||||
|
return "No recent log entries found in systemd journal", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
header := fmt.Sprintf("=== Systemd Journal Logs for %s.service (last %d entries, max %s) ===\n",
|
||||||
|
serviceName, maxLogEntries, maxLogAge.String())
|
||||||
|
|
||||||
|
return header + logs, nil
|
||||||
|
}
|
||||||
|
|
||||||
// addFirewallRules collects and adds firewall rules to the archive
|
// addFirewallRules collects and adds firewall rules to the archive
|
||||||
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
log.Info("Collecting firewall rules")
|
log.Info("Collecting firewall rules")
|
||||||
// Collect and add iptables rules
|
|
||||||
iptablesRules, err := collectIPTablesRules()
|
iptablesRules, err := collectIPTablesRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if req.GetAnonymize() {
|
if g.anonymize {
|
||||||
iptablesRules = anonymizer.AnonymizeString(iptablesRules)
|
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||||
}
|
}
|
||||||
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect and add nftables rules
|
|
||||||
nftablesRules, err := collectNFTablesRules()
|
nftablesRules, err := collectNFTablesRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to collect nftables rules: %v", err)
|
log.Warnf("Failed to collect nftables rules: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if req.GetAnonymize() {
|
if g.anonymize {
|
||||||
nftablesRules = anonymizer.AnonymizeString(nftablesRules)
|
nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
|
||||||
}
|
}
|
||||||
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
||||||
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -65,16 +146,23 @@ func collectIPTablesRules() (string, error) {
|
|||||||
builder.WriteString("\n")
|
builder.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then get verbose statistics for each table
|
// Collect ipset information
|
||||||
|
ipsetOutput, err := collectIPSets()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to collect ipset information: %v", err)
|
||||||
|
} else {
|
||||||
|
builder.WriteString("=== ipset list output ===\n")
|
||||||
|
builder.WriteString(ipsetOutput)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||||
|
|
||||||
// Get list of tables
|
|
||||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||||
|
|
||||||
// Get verbose statistics for the entire table
|
|
||||||
stats, err := getTableStatistics(table)
|
stats, err := getTableStatistics(table)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||||
@@ -87,6 +175,28 @@ func collectIPTablesRules() (string, error) {
|
|||||||
return builder.String(), nil
|
return builder.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// collectIPSets collects information about ipsets
|
||||||
|
func collectIPSets() (string, error) {
|
||||||
|
cmd := exec.Command("ipset", "list")
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if strings.Contains(err.Error(), "executable file not found") {
|
||||||
|
return "", fmt.Errorf("ipset command not found: %w", err)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
ipsets := stdout.String()
|
||||||
|
if strings.TrimSpace(ipsets) == "" {
|
||||||
|
return "No ipsets found", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipsets, nil
|
||||||
|
}
|
||||||
|
|
||||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||||
func collectIPTablesSave() (string, error) {
|
func collectIPTablesSave() (string, error) {
|
||||||
cmd := exec.Command("iptables-save")
|
cmd := exec.Command("iptables-save")
|
||||||
@@ -182,12 +292,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format chains
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
formatChain(conn, table, chain, &builder)
|
formatChain(conn, table, chain, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format sets
|
|
||||||
if sets, err := conn.GetSets(table); err != nil {
|
if sets, err := conn.GetSets(table); err != nil {
|
||||||
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
||||||
} else if len(sets) > 0 {
|
} else if len(sets) > 0 {
|
||||||
@@ -460,7 +568,7 @@ func formatExpr(exp expr.Any) string {
|
|||||||
case *expr.Fib:
|
case *expr.Fib:
|
||||||
return formatFib(e)
|
return formatFib(e)
|
||||||
case *expr.Target:
|
case *expr.Target:
|
||||||
return fmt.Sprintf("jump %s", e.Name) // Properly format jump targets
|
return fmt.Sprintf("jump %s", e.Name)
|
||||||
case *expr.Immediate:
|
case *expr.Immediate:
|
||||||
if e.Register == 1 {
|
if e.Register == 1 {
|
||||||
return formatImmediateData(e.Data)
|
return formatImmediateData(e.Data)
|
||||||
7
client/internal/debug/debug_mobile.go
Normal file
7
client/internal/debug/debug_mobile.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
14
client/internal/debug/debug_nonlinux.go
Normal file
14
client/internal/debug/debug_nonlinux.go
Normal file
@@ -0,0 +1,14 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
// collectFirewallRules returns nothing on non-linux systems
|
||||||
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) trySystemdLogFallback() error {
|
||||||
|
// Systemd is only available on Linux
|
||||||
|
// TODO: Add BSD support
|
||||||
|
return nil
|
||||||
|
}
|
||||||
25
client/internal/debug/debug_nonmobile.go
Normal file
25
client/internal/debug/debug_nonmobile.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
|
routes, err := systemops.GetRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: get routes including nexthop
|
||||||
|
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
|
||||||
|
routesReader := strings.NewReader(routesContent)
|
||||||
|
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add routes file to zip: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
543
client/internal/debug/debug_test.go
Normal file
543
client/internal/debug/debug_test.go
Normal file
@@ -0,0 +1,543 @@
|
|||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"encoding/json"
|
||||||
|
"net"
|
||||||
|
"strings"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/anonymize"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestAnonymizeStateFile(t *testing.T) {
|
||||||
|
testState := map[string]json.RawMessage{
|
||||||
|
"null_state": json.RawMessage("null"),
|
||||||
|
"test_state": mustMarshal(map[string]any{
|
||||||
|
// Test simple fields
|
||||||
|
"public_ip": "203.0.113.1",
|
||||||
|
"private_ip": "192.168.1.1",
|
||||||
|
"protected_ip": "100.64.0.1",
|
||||||
|
"well_known_ip": "8.8.8.8",
|
||||||
|
"ipv6_addr": "2001:db8::1",
|
||||||
|
"private_ipv6": "fd00::1",
|
||||||
|
"domain": "test.example.com",
|
||||||
|
"uri": "stun:stun.example.com:3478",
|
||||||
|
"uri_with_ip": "turn:203.0.113.1:3478",
|
||||||
|
"netbird_domain": "device.netbird.cloud",
|
||||||
|
|
||||||
|
// Test CIDR ranges
|
||||||
|
"public_cidr": "203.0.113.0/24",
|
||||||
|
"private_cidr": "192.168.0.0/16",
|
||||||
|
"protected_cidr": "100.64.0.0/10",
|
||||||
|
"ipv6_cidr": "2001:db8::/32",
|
||||||
|
"private_ipv6_cidr": "fd00::/8",
|
||||||
|
|
||||||
|
// Test nested structures
|
||||||
|
"nested": map[string]any{
|
||||||
|
"ip": "203.0.113.2",
|
||||||
|
"domain": "nested.example.com",
|
||||||
|
"more_nest": map[string]any{
|
||||||
|
"ip": "203.0.113.3",
|
||||||
|
"domain": "deep.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Test arrays
|
||||||
|
"string_array": []any{
|
||||||
|
"203.0.113.4",
|
||||||
|
"test1.example.com",
|
||||||
|
"test2.example.com",
|
||||||
|
},
|
||||||
|
"object_array": []any{
|
||||||
|
map[string]any{
|
||||||
|
"ip": "203.0.113.5",
|
||||||
|
"domain": "array1.example.com",
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"ip": "203.0.113.6",
|
||||||
|
"domain": "array2.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
|
||||||
|
// Test multiple occurrences of same value
|
||||||
|
"duplicate_ip": "203.0.113.1", // Same as public_ip
|
||||||
|
"duplicate_domain": "test.example.com", // Same as domain
|
||||||
|
|
||||||
|
// Test URIs with various schemes
|
||||||
|
"stun_uri": "stun:stun.example.com:3478",
|
||||||
|
"turns_uri": "turns:turns.example.com:5349",
|
||||||
|
"http_uri": "http://web.example.com:80",
|
||||||
|
"https_uri": "https://secure.example.com:443",
|
||||||
|
|
||||||
|
// Test strings that might look like IPs but aren't
|
||||||
|
"not_ip": "300.300.300.300",
|
||||||
|
"partial_ip": "192.168",
|
||||||
|
"ip_like_string": "1234.5678",
|
||||||
|
|
||||||
|
// Test mixed content strings
|
||||||
|
"mixed_content": "Server at 203.0.113.1 (test.example.com) on port 80",
|
||||||
|
|
||||||
|
// Test empty and special values
|
||||||
|
"empty_string": "",
|
||||||
|
"null_value": nil,
|
||||||
|
"numeric_value": 42,
|
||||||
|
"boolean_value": true,
|
||||||
|
}),
|
||||||
|
"route_state": mustMarshal(map[string]any{
|
||||||
|
"routes": []any{
|
||||||
|
map[string]any{
|
||||||
|
"network": "203.0.113.0/24",
|
||||||
|
"gateway": "203.0.113.1",
|
||||||
|
"domains": []any{
|
||||||
|
"route1.example.com",
|
||||||
|
"route2.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
map[string]any{
|
||||||
|
"network": "2001:db8::/32",
|
||||||
|
"gateway": "2001:db8::1",
|
||||||
|
"domains": []any{
|
||||||
|
"route3.example.com",
|
||||||
|
"route4.example.com",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
// Test map with IP/CIDR keys
|
||||||
|
"refCountMap": map[string]any{
|
||||||
|
"203.0.113.1/32": map[string]any{
|
||||||
|
"Count": 1,
|
||||||
|
"Out": map[string]any{
|
||||||
|
"IP": "192.168.0.1",
|
||||||
|
"Intf": map[string]any{
|
||||||
|
"Name": "eth0",
|
||||||
|
"Index": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"2001:db8::1/128": map[string]any{
|
||||||
|
"Count": 1,
|
||||||
|
"Out": map[string]any{
|
||||||
|
"IP": "fe80::1",
|
||||||
|
"Intf": map[string]any{
|
||||||
|
"Name": "eth0",
|
||||||
|
"Index": 1,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
"10.0.0.1/32": map[string]any{ // private IP should remain unchanged
|
||||||
|
"Count": 1,
|
||||||
|
"Out": map[string]any{
|
||||||
|
"IP": "192.168.0.1",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}),
|
||||||
|
}
|
||||||
|
|
||||||
|
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||||
|
|
||||||
|
// Pre-seed the domains we need to verify in the test assertions
|
||||||
|
anonymizer.AnonymizeDomain("test.example.com")
|
||||||
|
anonymizer.AnonymizeDomain("nested.example.com")
|
||||||
|
anonymizer.AnonymizeDomain("deep.example.com")
|
||||||
|
anonymizer.AnonymizeDomain("array1.example.com")
|
||||||
|
|
||||||
|
err := anonymizeStateFile(&testState, anonymizer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Helper function to unmarshal and get nested values
|
||||||
|
var state map[string]any
|
||||||
|
err = json.Unmarshal(testState["test_state"], &state)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test null state remains unchanged
|
||||||
|
require.Equal(t, "null", string(testState["null_state"]))
|
||||||
|
|
||||||
|
// Basic assertions
|
||||||
|
assert.NotEqual(t, "203.0.113.1", state["public_ip"])
|
||||||
|
assert.Equal(t, "192.168.1.1", state["private_ip"]) // Private IP unchanged
|
||||||
|
assert.Equal(t, "100.64.0.1", state["protected_ip"]) // Protected IP unchanged
|
||||||
|
assert.Equal(t, "8.8.8.8", state["well_known_ip"]) // Well-known IP unchanged
|
||||||
|
assert.NotEqual(t, "2001:db8::1", state["ipv6_addr"])
|
||||||
|
assert.Equal(t, "fd00::1", state["private_ipv6"]) // Private IPv6 unchanged
|
||||||
|
assert.NotEqual(t, "test.example.com", state["domain"])
|
||||||
|
assert.True(t, strings.HasSuffix(state["domain"].(string), ".domain"))
|
||||||
|
assert.Equal(t, "device.netbird.cloud", state["netbird_domain"]) // Netbird domain unchanged
|
||||||
|
|
||||||
|
// CIDR ranges
|
||||||
|
assert.NotEqual(t, "203.0.113.0/24", state["public_cidr"])
|
||||||
|
assert.Contains(t, state["public_cidr"], "/24") // Prefix preserved
|
||||||
|
assert.Equal(t, "192.168.0.0/16", state["private_cidr"]) // Private CIDR unchanged
|
||||||
|
assert.Equal(t, "100.64.0.0/10", state["protected_cidr"]) // Protected CIDR unchanged
|
||||||
|
assert.NotEqual(t, "2001:db8::/32", state["ipv6_cidr"])
|
||||||
|
assert.Contains(t, state["ipv6_cidr"], "/32") // IPv6 prefix preserved
|
||||||
|
|
||||||
|
// Nested structures
|
||||||
|
nested := state["nested"].(map[string]any)
|
||||||
|
assert.NotEqual(t, "203.0.113.2", nested["ip"])
|
||||||
|
assert.NotEqual(t, "nested.example.com", nested["domain"])
|
||||||
|
moreNest := nested["more_nest"].(map[string]any)
|
||||||
|
assert.NotEqual(t, "203.0.113.3", moreNest["ip"])
|
||||||
|
assert.NotEqual(t, "deep.example.com", moreNest["domain"])
|
||||||
|
|
||||||
|
// Arrays
|
||||||
|
strArray := state["string_array"].([]any)
|
||||||
|
assert.NotEqual(t, "203.0.113.4", strArray[0])
|
||||||
|
assert.NotEqual(t, "test1.example.com", strArray[1])
|
||||||
|
assert.True(t, strings.HasSuffix(strArray[1].(string), ".domain"))
|
||||||
|
|
||||||
|
objArray := state["object_array"].([]any)
|
||||||
|
firstObj := objArray[0].(map[string]any)
|
||||||
|
assert.NotEqual(t, "203.0.113.5", firstObj["ip"])
|
||||||
|
assert.NotEqual(t, "array1.example.com", firstObj["domain"])
|
||||||
|
|
||||||
|
// Duplicate values should be anonymized consistently
|
||||||
|
assert.Equal(t, state["public_ip"], state["duplicate_ip"])
|
||||||
|
assert.Equal(t, state["domain"], state["duplicate_domain"])
|
||||||
|
|
||||||
|
// URIs
|
||||||
|
assert.NotContains(t, state["stun_uri"], "stun.example.com")
|
||||||
|
assert.NotContains(t, state["turns_uri"], "turns.example.com")
|
||||||
|
assert.NotContains(t, state["http_uri"], "web.example.com")
|
||||||
|
assert.NotContains(t, state["https_uri"], "secure.example.com")
|
||||||
|
|
||||||
|
// Non-IP strings should remain unchanged
|
||||||
|
assert.Equal(t, "300.300.300.300", state["not_ip"])
|
||||||
|
assert.Equal(t, "192.168", state["partial_ip"])
|
||||||
|
assert.Equal(t, "1234.5678", state["ip_like_string"])
|
||||||
|
|
||||||
|
// Mixed content should have IPs and domains replaced
|
||||||
|
mixedContent := state["mixed_content"].(string)
|
||||||
|
assert.NotContains(t, mixedContent, "203.0.113.1")
|
||||||
|
assert.NotContains(t, mixedContent, "test.example.com")
|
||||||
|
assert.Contains(t, mixedContent, "Server at ")
|
||||||
|
assert.Contains(t, mixedContent, " on port 80")
|
||||||
|
|
||||||
|
// Special values should remain unchanged
|
||||||
|
assert.Equal(t, "", state["empty_string"])
|
||||||
|
assert.Nil(t, state["null_value"])
|
||||||
|
assert.Equal(t, float64(42), state["numeric_value"])
|
||||||
|
assert.Equal(t, true, state["boolean_value"])
|
||||||
|
|
||||||
|
// Check route state
|
||||||
|
var routeState map[string]any
|
||||||
|
err = json.Unmarshal(testState["route_state"], &routeState)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
routes := routeState["routes"].([]any)
|
||||||
|
route1 := routes[0].(map[string]any)
|
||||||
|
assert.NotEqual(t, "203.0.113.0/24", route1["network"])
|
||||||
|
assert.Contains(t, route1["network"], "/24")
|
||||||
|
assert.NotEqual(t, "203.0.113.1", route1["gateway"])
|
||||||
|
domains := route1["domains"].([]any)
|
||||||
|
assert.True(t, strings.HasSuffix(domains[0].(string), ".domain"))
|
||||||
|
assert.True(t, strings.HasSuffix(domains[1].(string), ".domain"))
|
||||||
|
|
||||||
|
// Check map keys are anonymized
|
||||||
|
refCountMap := routeState["refCountMap"].(map[string]any)
|
||||||
|
hasPublicIPKey := false
|
||||||
|
hasIPv6Key := false
|
||||||
|
hasPrivateIPKey := false
|
||||||
|
for key := range refCountMap {
|
||||||
|
if strings.Contains(key, "203.0.113.1") {
|
||||||
|
hasPublicIPKey = true
|
||||||
|
}
|
||||||
|
if strings.Contains(key, "2001:db8::1") {
|
||||||
|
hasIPv6Key = true
|
||||||
|
}
|
||||||
|
if key == "10.0.0.1/32" {
|
||||||
|
hasPrivateIPKey = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.False(t, hasPublicIPKey, "public IP in key should be anonymized")
|
||||||
|
assert.False(t, hasIPv6Key, "IPv6 in key should be anonymized")
|
||||||
|
assert.True(t, hasPrivateIPKey, "private IP in key should remain unchanged")
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustMarshal(v any) json.RawMessage {
|
||||||
|
data, err := json.Marshal(v)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return data
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnonymizeNetworkMap(t *testing.T) {
|
||||||
|
networkMap := &mgmProto.NetworkMap{
|
||||||
|
PeerConfig: &mgmProto.PeerConfig{
|
||||||
|
Address: "203.0.113.5",
|
||||||
|
Dns: "1.2.3.4",
|
||||||
|
Fqdn: "peer1.corp.example.com",
|
||||||
|
SshConfig: &mgmProto.SSHConfig{
|
||||||
|
SshPubKey: []byte("ssh-rsa AAAAB3NzaC1..."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
RemotePeers: []*mgmProto.RemotePeerConfig{
|
||||||
|
{
|
||||||
|
AllowedIps: []string{
|
||||||
|
"203.0.113.1/32",
|
||||||
|
"2001:db8:1234::1/128",
|
||||||
|
"192.168.1.1/32",
|
||||||
|
"100.64.0.1/32",
|
||||||
|
"10.0.0.1/32",
|
||||||
|
},
|
||||||
|
Fqdn: "peer2.corp.example.com",
|
||||||
|
SshConfig: &mgmProto.SSHConfig{
|
||||||
|
SshPubKey: []byte("ssh-rsa AAAAB3NzaC2..."),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
Routes: []*mgmProto.Route{
|
||||||
|
{
|
||||||
|
Network: "197.51.100.0/24",
|
||||||
|
Domains: []string{"prod.example.com", "staging.example.com"},
|
||||||
|
NetID: "net-123abc",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
DNSConfig: &mgmProto.DNSConfig{
|
||||||
|
NameServerGroups: []*mgmProto.NameServerGroup{
|
||||||
|
{
|
||||||
|
NameServers: []*mgmProto.NameServer{
|
||||||
|
{IP: "8.8.8.8"},
|
||||||
|
{IP: "1.1.1.1"},
|
||||||
|
{IP: "203.0.113.53"},
|
||||||
|
},
|
||||||
|
Domains: []string{"example.com", "internal.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
CustomZones: []*mgmProto.CustomZone{
|
||||||
|
{
|
||||||
|
Domain: "custom.example.com",
|
||||||
|
Records: []*mgmProto.SimpleRecord{
|
||||||
|
{
|
||||||
|
Name: "www.custom.example.com",
|
||||||
|
Type: 1,
|
||||||
|
RData: "203.0.113.10",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
Name: "internal.custom.example.com",
|
||||||
|
Type: 1,
|
||||||
|
RData: "192.168.1.10",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create anonymizer with test addresses
|
||||||
|
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||||
|
|
||||||
|
// Anonymize the network map
|
||||||
|
err := anonymizeNetworkMap(networkMap, anonymizer)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test PeerConfig anonymization
|
||||||
|
peerCfg := networkMap.PeerConfig
|
||||||
|
require.NotEqual(t, "203.0.113.5", peerCfg.Address)
|
||||||
|
|
||||||
|
// Verify DNS and FQDN are properly anonymized
|
||||||
|
require.NotEqual(t, "1.2.3.4", peerCfg.Dns)
|
||||||
|
require.NotEqual(t, "peer1.corp.example.com", peerCfg.Fqdn)
|
||||||
|
require.True(t, strings.HasSuffix(peerCfg.Fqdn, ".domain"))
|
||||||
|
|
||||||
|
// Verify SSH key is replaced
|
||||||
|
require.Equal(t, []byte("ssh-placeholder-key"), peerCfg.SshConfig.SshPubKey)
|
||||||
|
|
||||||
|
// Test RemotePeers anonymization
|
||||||
|
remotePeer := networkMap.RemotePeers[0]
|
||||||
|
|
||||||
|
// Verify FQDN is anonymized
|
||||||
|
require.NotEqual(t, "peer2.corp.example.com", remotePeer.Fqdn)
|
||||||
|
require.True(t, strings.HasSuffix(remotePeer.Fqdn, ".domain"))
|
||||||
|
|
||||||
|
// Check that public IPs are anonymized but private IPs are preserved
|
||||||
|
for _, allowedIP := range remotePeer.AllowedIps {
|
||||||
|
ip, _, err := net.ParseCIDR(allowedIP)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if ip.IsPrivate() || isInCGNATRange(ip) {
|
||||||
|
require.Contains(t, []string{
|
||||||
|
"192.168.1.1/32",
|
||||||
|
"100.64.0.1/32",
|
||||||
|
"10.0.0.1/32",
|
||||||
|
}, allowedIP)
|
||||||
|
} else {
|
||||||
|
require.NotContains(t, []string{
|
||||||
|
"203.0.113.1/32",
|
||||||
|
"2001:db8:1234::1/128",
|
||||||
|
}, allowedIP)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test Routes anonymization
|
||||||
|
route := networkMap.Routes[0]
|
||||||
|
require.NotEqual(t, "197.51.100.0/24", route.Network)
|
||||||
|
for _, domain := range route.Domains {
|
||||||
|
require.True(t, strings.HasSuffix(domain, ".domain"))
|
||||||
|
require.NotContains(t, domain, "example.com")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test DNS config anonymization
|
||||||
|
dnsConfig := networkMap.DNSConfig
|
||||||
|
nameServerGroup := dnsConfig.NameServerGroups[0]
|
||||||
|
|
||||||
|
// Verify well-known DNS servers are preserved
|
||||||
|
require.Equal(t, "8.8.8.8", nameServerGroup.NameServers[0].IP)
|
||||||
|
require.Equal(t, "1.1.1.1", nameServerGroup.NameServers[1].IP)
|
||||||
|
|
||||||
|
// Verify public DNS server is anonymized
|
||||||
|
require.NotEqual(t, "203.0.113.53", nameServerGroup.NameServers[2].IP)
|
||||||
|
|
||||||
|
// Verify domains are anonymized
|
||||||
|
for _, domain := range nameServerGroup.Domains {
|
||||||
|
require.True(t, strings.HasSuffix(domain, ".domain"))
|
||||||
|
require.NotContains(t, domain, "example.com")
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test CustomZones anonymization
|
||||||
|
customZone := dnsConfig.CustomZones[0]
|
||||||
|
require.True(t, strings.HasSuffix(customZone.Domain, ".domain"))
|
||||||
|
require.NotContains(t, customZone.Domain, "example.com")
|
||||||
|
|
||||||
|
// Verify records are properly anonymized
|
||||||
|
for _, record := range customZone.Records {
|
||||||
|
require.True(t, strings.HasSuffix(record.Name, ".domain"))
|
||||||
|
require.NotContains(t, record.Name, "example.com")
|
||||||
|
|
||||||
|
ip := net.ParseIP(record.RData)
|
||||||
|
if ip != nil {
|
||||||
|
if !ip.IsPrivate() {
|
||||||
|
require.NotEqual(t, "203.0.113.10", record.RData)
|
||||||
|
} else {
|
||||||
|
require.Equal(t, "192.168.1.10", record.RData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper function to check if IP is in CGNAT range
|
||||||
|
func isInCGNATRange(ip net.IP) bool {
|
||||||
|
cgnat := net.IPNet{
|
||||||
|
IP: net.ParseIP("100.64.0.0"),
|
||||||
|
Mask: net.CIDRMask(10, 32),
|
||||||
|
}
|
||||||
|
return cgnat.Contains(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestAnonymizeFirewallRules(t *testing.T) {
|
||||||
|
// TODO: Add ipv6
|
||||||
|
|
||||||
|
// Example iptables-save output
|
||||||
|
iptablesSave := `# Generated by iptables-save v1.8.7 on Thu Dec 19 10:00:00 2024
|
||||||
|
*filter
|
||||||
|
:INPUT ACCEPT [0:0]
|
||||||
|
:FORWARD ACCEPT [0:0]
|
||||||
|
:OUTPUT ACCEPT [0:0]
|
||||||
|
-A INPUT -s 192.168.1.0/24 -j ACCEPT
|
||||||
|
-A INPUT -s 44.192.140.1/32 -j DROP
|
||||||
|
-A FORWARD -s 10.0.0.0/8 -j DROP
|
||||||
|
-A FORWARD -s 44.192.140.0/24 -d 52.84.12.34/24 -j ACCEPT
|
||||||
|
COMMIT
|
||||||
|
|
||||||
|
*nat
|
||||||
|
:PREROUTING ACCEPT [0:0]
|
||||||
|
:INPUT ACCEPT [0:0]
|
||||||
|
:OUTPUT ACCEPT [0:0]
|
||||||
|
:POSTROUTING ACCEPT [0:0]
|
||||||
|
-A POSTROUTING -s 192.168.100.0/24 -j MASQUERADE
|
||||||
|
-A PREROUTING -d 44.192.140.10/32 -p tcp -m tcp --dport 80 -j DNAT --to-destination 192.168.1.10:80
|
||||||
|
COMMIT`
|
||||||
|
|
||||||
|
// Example iptables -v -n -L output
|
||||||
|
iptablesVerbose := `Chain INPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||||
|
pkts bytes target prot opt in out source destination
|
||||||
|
0 0 ACCEPT all -- * * 192.168.1.0/24 0.0.0.0/0
|
||||||
|
100 1024 DROP all -- * * 44.192.140.1 0.0.0.0/0
|
||||||
|
|
||||||
|
Chain FORWARD (policy ACCEPT 0 packets, 0 bytes)
|
||||||
|
pkts bytes target prot opt in out source destination
|
||||||
|
0 0 DROP all -- * * 10.0.0.0/8 0.0.0.0/0
|
||||||
|
25 256 ACCEPT all -- * * 44.192.140.0/24 52.84.12.34/24
|
||||||
|
|
||||||
|
Chain OUTPUT (policy ACCEPT 0 packets, 0 bytes)
|
||||||
|
pkts bytes target prot opt in out source destination`
|
||||||
|
|
||||||
|
// Example nftables output
|
||||||
|
nftablesRules := `table inet filter {
|
||||||
|
chain input {
|
||||||
|
type filter hook input priority filter; policy accept;
|
||||||
|
ip saddr 192.168.1.1 accept
|
||||||
|
ip saddr 44.192.140.1 drop
|
||||||
|
}
|
||||||
|
chain forward {
|
||||||
|
type filter hook forward priority filter; policy accept;
|
||||||
|
ip saddr 10.0.0.0/8 drop
|
||||||
|
ip saddr 44.192.140.0/24 ip daddr 52.84.12.34/24 accept
|
||||||
|
}
|
||||||
|
}`
|
||||||
|
|
||||||
|
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
||||||
|
|
||||||
|
// Test iptables-save anonymization
|
||||||
|
anonIptablesSave := anonymizer.AnonymizeString(iptablesSave)
|
||||||
|
|
||||||
|
// Private IP addresses should remain unchanged
|
||||||
|
assert.Contains(t, anonIptablesSave, "192.168.1.0/24")
|
||||||
|
assert.Contains(t, anonIptablesSave, "10.0.0.0/8")
|
||||||
|
assert.Contains(t, anonIptablesSave, "192.168.100.0/24")
|
||||||
|
assert.Contains(t, anonIptablesSave, "192.168.1.10")
|
||||||
|
|
||||||
|
// Public IP addresses should be anonymized to the default range
|
||||||
|
assert.NotContains(t, anonIptablesSave, "44.192.140.1")
|
||||||
|
assert.NotContains(t, anonIptablesSave, "44.192.140.0/24")
|
||||||
|
assert.NotContains(t, anonIptablesSave, "52.84.12.34")
|
||||||
|
assert.Contains(t, anonIptablesSave, "198.51.100.") // Default anonymous range
|
||||||
|
|
||||||
|
// Structure should be preserved
|
||||||
|
assert.Contains(t, anonIptablesSave, "*filter")
|
||||||
|
assert.Contains(t, anonIptablesSave, ":INPUT ACCEPT [0:0]")
|
||||||
|
assert.Contains(t, anonIptablesSave, "COMMIT")
|
||||||
|
assert.Contains(t, anonIptablesSave, "-j MASQUERADE")
|
||||||
|
assert.Contains(t, anonIptablesSave, "--dport 80")
|
||||||
|
|
||||||
|
// Test iptables verbose output anonymization
|
||||||
|
anonIptablesVerbose := anonymizer.AnonymizeString(iptablesVerbose)
|
||||||
|
|
||||||
|
// Private IP addresses should remain unchanged
|
||||||
|
assert.Contains(t, anonIptablesVerbose, "192.168.1.0/24")
|
||||||
|
assert.Contains(t, anonIptablesVerbose, "10.0.0.0/8")
|
||||||
|
|
||||||
|
// Public IP addresses should be anonymized to the default range
|
||||||
|
assert.NotContains(t, anonIptablesVerbose, "44.192.140.1")
|
||||||
|
assert.NotContains(t, anonIptablesVerbose, "44.192.140.0/24")
|
||||||
|
assert.NotContains(t, anonIptablesVerbose, "52.84.12.34")
|
||||||
|
assert.Contains(t, anonIptablesVerbose, "198.51.100.") // Default anonymous range
|
||||||
|
|
||||||
|
// Structure and counters should be preserved
|
||||||
|
assert.Contains(t, anonIptablesVerbose, "Chain INPUT (policy ACCEPT 0 packets, 0 bytes)")
|
||||||
|
assert.Contains(t, anonIptablesVerbose, "100 1024 DROP")
|
||||||
|
assert.Contains(t, anonIptablesVerbose, "pkts bytes target")
|
||||||
|
|
||||||
|
// Test nftables anonymization
|
||||||
|
anonNftables := anonymizer.AnonymizeString(nftablesRules)
|
||||||
|
|
||||||
|
// Private IP addresses should remain unchanged
|
||||||
|
assert.Contains(t, anonNftables, "192.168.1.1")
|
||||||
|
assert.Contains(t, anonNftables, "10.0.0.0/8")
|
||||||
|
|
||||||
|
// Public IP addresses should be anonymized to the default range
|
||||||
|
assert.NotContains(t, anonNftables, "44.192.140.1")
|
||||||
|
assert.NotContains(t, anonNftables, "44.192.140.0/24")
|
||||||
|
assert.NotContains(t, anonNftables, "52.84.12.34")
|
||||||
|
assert.Contains(t, anonNftables, "198.51.100.") // Default anonymous range
|
||||||
|
|
||||||
|
// Structure should be preserved
|
||||||
|
assert.Contains(t, anonNftables, "table inet filter {")
|
||||||
|
assert.Contains(t, anonNftables, "chain input {")
|
||||||
|
assert.Contains(t, anonNftables, "type filter hook input priority filter; policy accept;")
|
||||||
|
}
|
||||||
66
client/internal/debug/wgshow.go
Normal file
66
client/internal/debug/wgshow.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
|
)
|
||||||
|
|
||||||
|
type WGIface interface {
|
||||||
|
FullStats() (*configurer.Stats, error)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addWgShow() error {
|
||||||
|
result, err := g.statusRecorder.PeersStatus()
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
output := g.toWGShowFormat(result)
|
||||||
|
reader := bytes.NewReader([]byte(output))
|
||||||
|
|
||||||
|
if err := g.addFileToZip(reader, "wgshow.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add wg show to zip: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (g *BundleGenerator) toWGShowFormat(s *configurer.Stats) string {
|
||||||
|
var sb strings.Builder
|
||||||
|
|
||||||
|
sb.WriteString(fmt.Sprintf("interface: %s\n", s.DeviceName))
|
||||||
|
sb.WriteString(fmt.Sprintf(" public key: %s\n", s.PublicKey))
|
||||||
|
sb.WriteString(fmt.Sprintf(" listen port: %d\n", s.ListenPort))
|
||||||
|
if s.FWMark != 0 {
|
||||||
|
sb.WriteString(fmt.Sprintf(" fwmark: %#x\n", s.FWMark))
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, peer := range s.Peers {
|
||||||
|
sb.WriteString(fmt.Sprintf("\npeer: %s\n", peer.PublicKey))
|
||||||
|
if peer.Endpoint.IP != nil {
|
||||||
|
if g.anonymize {
|
||||||
|
anonEndpoint := g.anonymizer.AnonymizeUDPAddr(peer.Endpoint)
|
||||||
|
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", anonEndpoint.String()))
|
||||||
|
} else {
|
||||||
|
sb.WriteString(fmt.Sprintf(" endpoint: %s\n", peer.Endpoint.String()))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if len(peer.AllowedIPs) > 0 {
|
||||||
|
var ipStrings []string
|
||||||
|
for _, ipnet := range peer.AllowedIPs {
|
||||||
|
ipStrings = append(ipStrings, ipnet.String())
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" allowed ips: %s\n", strings.Join(ipStrings, ", ")))
|
||||||
|
}
|
||||||
|
sb.WriteString(fmt.Sprintf(" latest handshake: %s\n", peer.LastHandshake.Format(time.RFC1123)))
|
||||||
|
sb.WriteString(fmt.Sprintf(" transfer: %d B received, %d B sent\n", peer.RxBytes, peer.TxBytes))
|
||||||
|
if peer.PresharedKey {
|
||||||
|
sb.WriteString(" preshared key: (hidden)\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return sb.String()
|
||||||
|
}
|
||||||
@@ -2,7 +2,7 @@ package internal
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
@@ -12,13 +12,14 @@ import (
|
|||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.SimpleRecord, bool) {
|
func createPTRRecord(aRecord nbdns.SimpleRecord, prefix netip.Prefix) (nbdns.SimpleRecord, bool) {
|
||||||
ip := net.ParseIP(aRecord.RData)
|
ip, err := netip.ParseAddr(aRecord.RData)
|
||||||
if ip == nil || ip.To4() == nil {
|
if err != nil {
|
||||||
|
log.Warnf("failed to parse IP address %s: %v", aRecord.RData, err)
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if !ipNet.Contains(ip) {
|
if !prefix.Contains(ip) {
|
||||||
return nbdns.SimpleRecord{}, false
|
return nbdns.SimpleRecord{}, false
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -36,16 +37,19 @@ func createPTRRecord(aRecord nbdns.SimpleRecord, ipNet *net.IPNet) (nbdns.Simple
|
|||||||
}
|
}
|
||||||
|
|
||||||
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
// generateReverseZoneName creates the reverse DNS zone name for a given network
|
||||||
func generateReverseZoneName(ipNet *net.IPNet) (string, error) {
|
func generateReverseZoneName(network netip.Prefix) (string, error) {
|
||||||
networkIP := ipNet.IP.Mask(ipNet.Mask)
|
networkIP := network.Masked().Addr()
|
||||||
maskOnes, _ := ipNet.Mask.Size()
|
|
||||||
|
if !networkIP.Is4() {
|
||||||
|
return "", fmt.Errorf("reverse DNS is only supported for IPv4 networks, got: %s", networkIP)
|
||||||
|
}
|
||||||
|
|
||||||
// round up to nearest byte
|
// round up to nearest byte
|
||||||
octetsToUse := (maskOnes + 7) / 8
|
octetsToUse := (network.Bits() + 7) / 8
|
||||||
|
|
||||||
octets := strings.Split(networkIP.String(), ".")
|
octets := strings.Split(networkIP.String(), ".")
|
||||||
if octetsToUse > len(octets) {
|
if octetsToUse > len(octets) {
|
||||||
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", maskOnes)
|
return "", fmt.Errorf("invalid network mask size for reverse DNS: %d", network.Bits())
|
||||||
}
|
}
|
||||||
|
|
||||||
reverseOctets := make([]string, octetsToUse)
|
reverseOctets := make([]string, octetsToUse)
|
||||||
@@ -68,7 +72,7 @@ func zoneExists(config *nbdns.Config, zoneName string) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// collectPTRRecords gathers all PTR records for the given network from A records
|
// collectPTRRecords gathers all PTR records for the given network from A records
|
||||||
func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRecord {
|
func collectPTRRecords(config *nbdns.Config, prefix netip.Prefix) []nbdns.SimpleRecord {
|
||||||
var records []nbdns.SimpleRecord
|
var records []nbdns.SimpleRecord
|
||||||
|
|
||||||
for _, zone := range config.CustomZones {
|
for _, zone := range config.CustomZones {
|
||||||
@@ -77,7 +81,7 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if ptrRecord, ok := createPTRRecord(record, ipNet); ok {
|
if ptrRecord, ok := createPTRRecord(record, prefix); ok {
|
||||||
records = append(records, ptrRecord)
|
records = append(records, ptrRecord)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -87,8 +91,8 @@ func collectPTRRecords(config *nbdns.Config, ipNet *net.IPNet) []nbdns.SimpleRec
|
|||||||
}
|
}
|
||||||
|
|
||||||
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
// addReverseZone adds a reverse DNS zone to the configuration for the given network
|
||||||
func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
func addReverseZone(config *nbdns.Config, network netip.Prefix) {
|
||||||
zoneName, err := generateReverseZoneName(ipNet)
|
zoneName, err := generateReverseZoneName(network)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warn(err)
|
log.Warn(err)
|
||||||
return
|
return
|
||||||
@@ -99,7 +103,7 @@ func addReverseZone(config *nbdns.Config, ipNet *net.IPNet) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
records := collectPTRRecords(config, ipNet)
|
records := collectPTRRecords(config, network)
|
||||||
|
|
||||||
reverseZone := nbdns.CustomZone{
|
reverseZone := nbdns.CustomZone{
|
||||||
Domain: zoneName,
|
Domain: zoneName,
|
||||||
|
|||||||
@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||||
}
|
}
|
||||||
return listOfDomains
|
return listOfDomains
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,15 @@
|
|||||||
package dns
|
package dns
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"slices"
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -23,8 +26,8 @@ type SubdomainMatcher interface {
|
|||||||
type HandlerEntry struct {
|
type HandlerEntry struct {
|
||||||
Handler dns.Handler
|
Handler dns.Handler
|
||||||
Priority int
|
Priority int
|
||||||
Pattern string
|
Pattern domain.Domain
|
||||||
OrigPattern string
|
OrigPattern domain.Domain
|
||||||
IsWildcard bool
|
IsWildcard bool
|
||||||
MatchSubdomains bool
|
MatchSubdomains bool
|
||||||
}
|
}
|
||||||
@@ -38,7 +41,7 @@ type HandlerChain struct {
|
|||||||
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
// ResponseWriterChain wraps a dns.ResponseWriter to track if handler wants to continue chain
|
||||||
type ResponseWriterChain struct {
|
type ResponseWriterChain struct {
|
||||||
dns.ResponseWriter
|
dns.ResponseWriter
|
||||||
origPattern string
|
origPattern domain.Domain
|
||||||
shouldContinue bool
|
shouldContinue bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -58,29 +61,24 @@ func NewHandlerChain() *HandlerChain {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
// GetOrigPattern returns the original pattern of the handler that wrote the response
|
||||||
func (w *ResponseWriterChain) GetOrigPattern() string {
|
func (w *ResponseWriterChain) GetOrigPattern() domain.Domain {
|
||||||
return w.origPattern
|
return w.origPattern
|
||||||
}
|
}
|
||||||
|
|
||||||
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
// AddHandler adds a new handler to the chain, replacing any existing handler with the same pattern and priority
|
||||||
func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority int) {
|
func (c *HandlerChain) AddHandler(pattern domain.Domain, handler dns.Handler, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
pattern = domain.Domain(strings.ToLower(dns.Fqdn(pattern.PunycodeString())))
|
||||||
origPattern := pattern
|
origPattern := pattern
|
||||||
isWildcard := strings.HasPrefix(pattern, "*.")
|
isWildcard := strings.HasPrefix(pattern.PunycodeString(), "*.")
|
||||||
if isWildcard {
|
if isWildcard {
|
||||||
pattern = pattern[2:]
|
pattern = pattern[2:]
|
||||||
}
|
}
|
||||||
|
|
||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
c.removeEntry(origPattern, priority)
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if handler implements SubdomainMatcher interface
|
// Check if handler implements SubdomainMatcher interface
|
||||||
matchSubdomains := false
|
matchSubdomains := false
|
||||||
@@ -114,8 +112,8 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
|||||||
|
|
||||||
// domain specificity next
|
// domain specificity next
|
||||||
if h.Priority == newEntry.Priority {
|
if h.Priority == newEntry.Priority {
|
||||||
newDots := strings.Count(newEntry.Pattern, ".")
|
newDots := strings.Count(newEntry.Pattern.PunycodeString(), ".")
|
||||||
existingDots := strings.Count(h.Pattern, ".")
|
existingDots := strings.Count(h.Pattern.PunycodeString(), ".")
|
||||||
if newDots > existingDots {
|
if newDots > existingDots {
|
||||||
return i
|
return i
|
||||||
}
|
}
|
||||||
@@ -127,83 +125,53 @@ func (c *HandlerChain) findHandlerPosition(newEntry HandlerEntry) int {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// RemoveHandler removes a handler for the given pattern and priority
|
// RemoveHandler removes a handler for the given pattern and priority
|
||||||
func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
func (c *HandlerChain) RemoveHandler(pattern domain.Domain, priority int) {
|
||||||
c.mu.Lock()
|
c.mu.Lock()
|
||||||
defer c.mu.Unlock()
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = domain.Domain(dns.Fqdn(pattern.PunycodeString()))
|
||||||
|
|
||||||
|
c.removeEntry(pattern, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) removeEntry(pattern domain.Domain, priority int) {
|
||||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern.PunycodeString(), pattern.PunycodeString()) && entry.Priority == priority {
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
|
||||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
|
||||||
for _, entry := range c.handlers {
|
|
||||||
if strings.EqualFold(entry.Pattern, pattern) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
qname := strings.ToLower(r.Question[0].Name)
|
qname := strings.ToLower(r.Question[0].Name)
|
||||||
log.Tracef("handling DNS request for domain=%s", qname)
|
|
||||||
|
|
||||||
c.mu.RLock()
|
c.mu.RLock()
|
||||||
handlers := slices.Clone(c.handlers)
|
handlers := slices.Clone(c.handlers)
|
||||||
c.mu.RUnlock()
|
c.mu.RUnlock()
|
||||||
|
|
||||||
if log.IsLevelEnabled(log.TraceLevel) {
|
if log.IsLevelEnabled(log.TraceLevel) {
|
||||||
log.Tracef("current handlers (%d):", len(handlers))
|
var b strings.Builder
|
||||||
|
b.WriteString(fmt.Sprintf("DNS request domain=%s, handlers (%d):\n", qname, len(handlers)))
|
||||||
for _, h := range handlers {
|
for _, h := range handlers {
|
||||||
log.Tracef(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
b.WriteString(fmt.Sprintf(" - pattern: domain=%s original: domain=%s wildcard=%v match_subdomain=%v priority=%d\n",
|
||||||
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority)
|
h.Pattern, h.OrigPattern, h.IsWildcard, h.MatchSubdomains, h.Priority))
|
||||||
}
|
}
|
||||||
|
log.Trace(strings.TrimSuffix(b.String(), "\n"))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Try handlers in priority order
|
// Try handlers in priority order
|
||||||
for _, entry := range handlers {
|
for _, entry := range handlers {
|
||||||
var matched bool
|
matched := c.isHandlerMatch(qname, entry)
|
||||||
switch {
|
|
||||||
case entry.Pattern == ".":
|
|
||||||
matched = true
|
|
||||||
case entry.IsWildcard:
|
|
||||||
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern), ".")
|
|
||||||
matched = len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern)
|
|
||||||
default:
|
|
||||||
// For non-wildcard patterns:
|
|
||||||
// If handler wants subdomain matching, allow suffix match
|
|
||||||
// Otherwise require exact match
|
|
||||||
if entry.MatchSubdomains {
|
|
||||||
matched = strings.EqualFold(qname, entry.Pattern) || strings.HasSuffix(qname, "."+entry.Pattern)
|
|
||||||
} else {
|
|
||||||
matched = strings.EqualFold(qname, entry.Pattern)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if !matched {
|
if matched {
|
||||||
log.Tracef("trying domain match: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d matched=false",
|
log.Tracef("handler matched: domain=%s -> pattern=%s wildcard=%v match_subdomain=%v priority=%d",
|
||||||
qname, entry.OrigPattern, entry.MatchSubdomains, entry.IsWildcard, entry.Priority)
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Tracef("handler matched: request: domain=%s pattern: domain=%s wildcard=%v match_subdomain=%v priority=%d",
|
|
||||||
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
qname, entry.OrigPattern, entry.IsWildcard, entry.MatchSubdomains, entry.Priority)
|
||||||
|
|
||||||
chainWriter := &ResponseWriterChain{
|
chainWriter := &ResponseWriterChain{
|
||||||
@@ -214,11 +182,12 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
|
|
||||||
// If handler wants to continue, try next handler
|
// If handler wants to continue, try next handler
|
||||||
if chainWriter.shouldContinue {
|
if chainWriter.shouldContinue {
|
||||||
log.Tracef("handler requested continue to next handler")
|
log.Tracef("handler requested continue to next handler for domain=%s", qname)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// No handler matched or all handlers passed
|
// No handler matched or all handlers passed
|
||||||
log.Tracef("no handler found for domain=%s", qname)
|
log.Tracef("no handler found for domain=%s", qname)
|
||||||
@@ -228,3 +197,22 @@ func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
log.Errorf("failed to write DNS response: %v", err)
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) isHandlerMatch(qname string, entry HandlerEntry) bool {
|
||||||
|
switch {
|
||||||
|
case entry.Pattern == ".":
|
||||||
|
return true
|
||||||
|
case entry.IsWildcard:
|
||||||
|
parts := strings.Split(strings.TrimSuffix(qname, entry.Pattern.PunycodeString()), ".")
|
||||||
|
return len(parts) >= 2 && strings.HasSuffix(qname, entry.Pattern.PunycodeString())
|
||||||
|
default:
|
||||||
|
// For non-wildcard patterns:
|
||||||
|
// If handler wants subdomain matching, allow suffix match
|
||||||
|
// Otherwise require exact match
|
||||||
|
if entry.MatchSubdomains {
|
||||||
|
return strings.EqualFold(qname, entry.Pattern.PunycodeString()) || strings.HasSuffix(qname, "."+entry.Pattern.PunycodeString())
|
||||||
|
} else {
|
||||||
|
return strings.EqualFold(qname, entry.Pattern.PunycodeString())
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,7 +1,6 @@
|
|||||||
package dns_test
|
package dns_test
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
@@ -9,6 +8,8 @@ import (
|
|||||||
"github.com/stretchr/testify/mock"
|
"github.com/stretchr/testify/mock"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
nbdns "github.com/netbirdio/netbird/client/internal/dns"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/dns/test"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
// TestHandlerChain_ServeDNS_Priorities tests that handlers are executed in priority order
|
||||||
@@ -30,7 +31,7 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
r.SetQuestion("example.com.", dns.TypeA)
|
r.SetQuestion("example.com.", dns.TypeA)
|
||||||
|
|
||||||
// Create test writer
|
// Create test writer
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup expectations - only highest priority handler should be called
|
// Setup expectations - only highest priority handler should be called
|
||||||
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
dnsRouteHandler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
@@ -50,8 +51,8 @@ func TestHandlerChain_ServeDNS_Priorities(t *testing.T) {
|
|||||||
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
handlerDomain string
|
handlerDomain domain.Domain
|
||||||
queryDomain string
|
queryDomain domain.Domain
|
||||||
isWildcard bool
|
isWildcard bool
|
||||||
matchSubdomains bool
|
matchSubdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -141,8 +142,8 @@ func TestHandlerChain_ServeDNS_DomainMatching(t *testing.T) {
|
|||||||
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
chain.AddHandler(pattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
@@ -160,17 +161,17 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
handlers []struct {
|
handlers []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
queryDomain string
|
queryDomain domain.Domain
|
||||||
expectedCalls int
|
expectedCalls int
|
||||||
expectedHandler int // index of the handler that should be called
|
expectedHandler int // index of the handler that should be called
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "wildcard and exact same priority - exact should win",
|
name: "wildcard and exact same priority - exact should win",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -183,7 +184,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "higher priority wildcard over lower priority exact",
|
name: "higher priority wildcard over lower priority exact",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -196,7 +197,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "multiple wildcards different priorities",
|
name: "multiple wildcards different priorities",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -210,7 +211,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "subdomain with mix of patterns",
|
name: "subdomain with mix of patterns",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
{pattern: "*.example.com.", priority: nbdns.PriorityDefault},
|
||||||
@@ -224,7 +225,7 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "root zone with specific domain",
|
name: "root zone with specific domain",
|
||||||
handlers: []struct {
|
handlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{pattern: ".", priority: nbdns.PriorityDefault},
|
{pattern: ".", priority: nbdns.PriorityDefault},
|
||||||
@@ -258,8 +259,8 @@ func TestHandlerChain_ServeDNS_OverlappingDomains(t *testing.T) {
|
|||||||
|
|
||||||
// Create and execute request
|
// Create and execute request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.queryDomain, dns.TypeA)
|
r.SetQuestion(tt.queryDomain.PunycodeString(), dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
// Verify expectations
|
// Verify expectations
|
||||||
@@ -316,7 +317,7 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
}).Once()
|
}).Once()
|
||||||
|
|
||||||
// Execute
|
// Execute
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
// Verify all handlers were called in order
|
// Verify all handlers were called in order
|
||||||
@@ -325,26 +326,12 @@ func TestHandlerChain_ServeDNS_ChainContinuation(t *testing.T) {
|
|||||||
handler3.AssertExpectations(t)
|
handler3.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// mockResponseWriter implements dns.ResponseWriter for testing
|
|
||||||
type mockResponseWriter struct {
|
|
||||||
mock.Mock
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *mockResponseWriter) LocalAddr() net.Addr { return nil }
|
|
||||||
func (m *mockResponseWriter) RemoteAddr() net.Addr { return nil }
|
|
||||||
func (m *mockResponseWriter) WriteMsg(*dns.Msg) error { return nil }
|
|
||||||
func (m *mockResponseWriter) Write([]byte) (int, error) { return 0, nil }
|
|
||||||
func (m *mockResponseWriter) Close() error { return nil }
|
|
||||||
func (m *mockResponseWriter) TsigStatus() error { return nil }
|
|
||||||
func (m *mockResponseWriter) TsigTimersOnly(bool) {}
|
|
||||||
func (m *mockResponseWriter) Hijack() {}
|
|
||||||
|
|
||||||
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
ops []struct {
|
ops []struct {
|
||||||
action string // "add" or "remove"
|
action string // "add" or "remove"
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}
|
}
|
||||||
query string
|
query string
|
||||||
@@ -354,7 +341,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
name: "remove high priority keeps lower priority handler",
|
name: "remove high priority keeps lower priority handler",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
@@ -371,7 +358,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
name: "remove lower priority keeps high priority handler",
|
name: "remove lower priority keeps high priority handler",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
@@ -388,7 +375,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
name: "remove all handlers in order",
|
name: "remove all handlers in order",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
}{
|
}{
|
||||||
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
{"add", "example.com.", nbdns.PriorityDNSRoute},
|
||||||
@@ -425,7 +412,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
// Create test request
|
// Create test request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup expectations
|
// Setup expectations
|
||||||
for priority, handler := range handlers {
|
for priority, handler := range handlers {
|
||||||
@@ -443,14 +430,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
for _, handler := range handlers {
|
for _, handler := range handlers {
|
||||||
handler.AssertExpectations(t)
|
handler.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify handler exists check
|
|
||||||
for priority, shouldExist := range tt.expectedCalls {
|
|
||||||
if shouldExist {
|
|
||||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
|
||||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -458,7 +437,7 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
testDomain := "example.com."
|
testDomain := domain.Domain("example.com.")
|
||||||
testQuery := "test.example.com."
|
testQuery := "test.example.com."
|
||||||
|
|
||||||
// Create handlers with MatchSubdomains enabled
|
// Create handlers with MatchSubdomains enabled
|
||||||
@@ -470,45 +449,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
|
// Keep track of mocks for the final assertion in Step 4
|
||||||
|
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Highest priority handler (routeHandler) should be called
|
// Highest priority handler (routeHandler) should be called
|
||||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w1, r)
|
||||||
routeHandler.AssertExpectations(t)
|
routeHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
routeHandler.ExpectedCalls = nil
|
||||||
|
routeHandler.Calls = nil
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 2: Remove highest priority handler
|
// Test 2: Remove highest priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now middle priority handler (matchHandler) should be called
|
// Now middle priority handler (matchHandler) should be called
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w2, r)
|
||||||
matchHandler.AssertExpectations(t)
|
matchHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w3, r)
|
||||||
defaultHandler.AssertExpectations(t)
|
defaultHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
assert.False(t, chain.HasHandlers(testDomain))
|
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||||
|
|
||||||
|
for _, m := range mocks {
|
||||||
|
m.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||||
@@ -516,7 +519,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
scenario string
|
scenario string
|
||||||
addHandlers []struct {
|
addHandlers []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -528,7 +531,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "case insensitive exact match",
|
name: "case insensitive exact match",
|
||||||
scenario: "handler registered lowercase, query uppercase",
|
scenario: "handler registered lowercase, query uppercase",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -542,7 +545,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "case insensitive wildcard match",
|
name: "case insensitive wildcard match",
|
||||||
scenario: "handler registered mixed case wildcard, query different case",
|
scenario: "handler registered mixed case wildcard, query different case",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -556,7 +559,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "multiple handlers different case same domain",
|
name: "multiple handlers different case same domain",
|
||||||
scenario: "second handler should replace first despite case difference",
|
scenario: "second handler should replace first despite case difference",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -571,7 +574,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "subdomain matching case insensitive",
|
name: "subdomain matching case insensitive",
|
||||||
scenario: "handler with MatchSubdomains true should match regardless of case",
|
scenario: "handler with MatchSubdomains true should match regardless of case",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -585,7 +588,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "root zone case insensitive",
|
name: "root zone case insensitive",
|
||||||
scenario: "root zone handler should match regardless of case",
|
scenario: "root zone handler should match regardless of case",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -599,7 +602,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
name: "multiple handlers different priority",
|
name: "multiple handlers different priority",
|
||||||
scenario: "should call higher priority handler despite case differences",
|
scenario: "should call higher priority handler despite case differences",
|
||||||
addHandlers: []struct {
|
addHandlers: []struct {
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomains bool
|
subdomains bool
|
||||||
shouldMatch bool
|
shouldMatch bool
|
||||||
@@ -616,7 +619,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
handlerCalls := make(map[string]bool) // track which patterns were called
|
handlerCalls := make(map[domain.Domain]bool) // track which patterns were called
|
||||||
|
|
||||||
// Add handlers according to test case
|
// Add handlers according to test case
|
||||||
for _, h := range tt.addHandlers {
|
for _, h := range tt.addHandlers {
|
||||||
@@ -659,7 +662,7 @@ func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
|||||||
// Execute request
|
// Execute request
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query, dns.TypeA)
|
||||||
chain.ServeDNS(&mockResponseWriter{}, r)
|
chain.ServeDNS(&test.MockResponseWriter{}, r)
|
||||||
|
|
||||||
// Verify each handler was called exactly as expected
|
// Verify each handler was called exactly as expected
|
||||||
for _, h := range tt.addHandlers {
|
for _, h := range tt.addHandlers {
|
||||||
@@ -684,19 +687,19 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario string
|
scenario string
|
||||||
ops []struct {
|
ops []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}
|
}
|
||||||
query string
|
query domain.Domain
|
||||||
expectedMatch string
|
expectedMatch domain.Domain
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
name: "more specific domain matches first",
|
name: "more specific domain matches first",
|
||||||
scenario: "sub.example.com should match before example.com",
|
scenario: "sub.example.com should match before example.com",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -711,7 +714,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "sub.example.com should match before example.com",
|
scenario: "sub.example.com should match before example.com",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -726,7 +729,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "after removing most specific, should fall back to less specific",
|
scenario: "after removing most specific, should fall back to less specific",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -743,7 +746,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "less specific domain with higher priority should match first",
|
scenario: "less specific domain with higher priority should match first",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -758,7 +761,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "with equal priority, more specific domain should match",
|
scenario: "with equal priority, more specific domain should match",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -774,7 +777,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
scenario: "specific domain should match before wildcard at same priority",
|
scenario: "specific domain should match before wildcard at same priority",
|
||||||
ops: []struct {
|
ops: []struct {
|
||||||
action string
|
action string
|
||||||
pattern string
|
pattern domain.Domain
|
||||||
priority int
|
priority int
|
||||||
subdomain bool
|
subdomain bool
|
||||||
}{
|
}{
|
||||||
@@ -789,7 +792,7 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
chain := nbdns.NewHandlerChain()
|
chain := nbdns.NewHandlerChain()
|
||||||
handlers := make(map[string]*nbdns.MockSubdomainHandler)
|
handlers := make(map[domain.Domain]*nbdns.MockSubdomainHandler)
|
||||||
|
|
||||||
for _, op := range tt.ops {
|
for _, op := range tt.ops {
|
||||||
if op.action == "add" {
|
if op.action == "add" {
|
||||||
@@ -802,8 +805,8 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(tt.query, dns.TypeA)
|
r.SetQuestion(tt.query.PunycodeString(), dns.TypeA)
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
// Setup handler expectations
|
// Setup handler expectations
|
||||||
for pattern, handler := range handlers {
|
for pattern, handler := range handlers {
|
||||||
@@ -830,3 +833,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addPattern domain.Domain
|
||||||
|
removePattern domain.Domain
|
||||||
|
queryPattern domain.Domain
|
||||||
|
shouldBeRemoved bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact same pattern",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case difference",
|
||||||
|
addPattern: "Example.Com.",
|
||||||
|
removePattern: "EXAMPLE.COM.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with mixed case, removing with uppercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed case difference",
|
||||||
|
addPattern: "EXAMPLE.ORG.",
|
||||||
|
removePattern: "example.org.",
|
||||||
|
queryPattern: "example.org.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with uppercase, removing with lowercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove wildcard",
|
||||||
|
addPattern: "*.example.com.",
|
||||||
|
removePattern: "*.example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical wildcard patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove transformed pattern",
|
||||||
|
addPattern: "*.example.net.",
|
||||||
|
removePattern: "example.net.",
|
||||||
|
queryPattern: "sub.example.net.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with wildcard, removing with non-wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add transformed pattern, remove wildcard",
|
||||||
|
addPattern: "example.io.",
|
||||||
|
removePattern: "*.example.io.",
|
||||||
|
queryPattern: "example.io.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing dot difference",
|
||||||
|
addPattern: "example.dev",
|
||||||
|
removePattern: "example.dev.",
|
||||||
|
queryPattern: "example.dev.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding without trailing dot, removing with trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed trailing dot difference",
|
||||||
|
addPattern: "example.app.",
|
||||||
|
removePattern: "example.app",
|
||||||
|
queryPattern: "example.app.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with trailing dot, removing without trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case and wildcard",
|
||||||
|
addPattern: "*.Example.Site.",
|
||||||
|
removePattern: "*.EXAMPLE.SITE.",
|
||||||
|
queryPattern: "sub.example.site.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding mixed case wildcard, removing uppercase wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone",
|
||||||
|
addPattern: ".",
|
||||||
|
removePattern: ".",
|
||||||
|
queryPattern: "random.domain.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing root zone",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong domain",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "different.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding one domain, trying to remove a different domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain mismatch",
|
||||||
|
addPattern: "sub.example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding subdomain, trying to remove parent domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parent domain mismatch",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "sub.example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding parent domain, trying to remove subdomain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
handler := &nbdns.MockHandler{}
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryPattern.PunycodeString(), dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &test.MockResponseWriter{}}
|
||||||
|
|
||||||
|
// First verify no handler is called before adding any
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS")
|
||||||
|
|
||||||
|
// Add handler
|
||||||
|
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Verify handler is called after adding
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Reset mock for the next test
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
|
||||||
|
// Remove handler
|
||||||
|
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Set up expectations based on whether removal should succeed
|
||||||
|
if !tt.shouldBeRemoved {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if handler is still called after removal attempt
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
if tt.shouldBeRemoved {
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS",
|
||||||
|
"Handler should not be called after successful removal with pattern %q",
|
||||||
|
tt.removePattern)
|
||||||
|
} else {
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,15 +5,18 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipv4ReverseZone = ".in-addr.arpa"
|
ipv4ReverseZone = ".in-addr.arpa."
|
||||||
ipv6ReverseZone = ".ip6.arpa"
|
ipv6ReverseZone = ".ip6.arpa."
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
@@ -38,7 +41,7 @@ type HostDNSConfig struct {
|
|||||||
|
|
||||||
type DomainConfig struct {
|
type DomainConfig struct {
|
||||||
Disabled bool `json:"disabled"`
|
Disabled bool `json:"disabled"`
|
||||||
Domain string `json:"domain"`
|
Domain domain.Domain `json:"domain"`
|
||||||
MatchOnly bool `json:"matchOnly"`
|
MatchOnly bool `json:"matchOnly"`
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -101,18 +104,20 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
config.RouteAll = true
|
config.RouteAll = true
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, domain := range nsConfig.Domains {
|
for _, d := range nsConfig.Domains {
|
||||||
|
d := strings.ToLower(dns.Fqdn(d.PunycodeString()))
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(domain, "."),
|
Domain: domain.Domain(d),
|
||||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
d := strings.ToLower(dns.Fqdn(customZone.Domain))
|
||||||
|
matchOnly := strings.HasSuffix(d, ipv4ReverseZone) || strings.HasSuffix(d, ipv6ReverseZone)
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
Domain: domain.Domain(d),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: matchOnly,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.Domain)
|
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain.PunycodeString(), "."))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain.PunycodeString(), "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user