mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 08:46:38 +00:00
Compare commits
306 Commits
v0.33.0
...
test/netwo
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
4d2c774378 | ||
|
|
ab2e3fec72 | ||
|
|
47f88f7057 | ||
|
|
ee33a6ed7c | ||
|
|
da662cfd08 | ||
|
|
ed2ee1ee9d | ||
|
|
76d73548d6 | ||
|
|
11828a064a | ||
|
|
0c2a3dd937 | ||
|
|
cd9eff5331 | ||
|
|
47dcf8d68c | ||
|
|
80ceb80197 | ||
|
|
cc8f6bcaf3 | ||
|
|
636a0e2475 | ||
|
|
e66e329bf6 | ||
|
|
aaa23beeec | ||
|
|
6bef474e9e | ||
|
|
81040ff80a | ||
|
|
c73481aee4 | ||
|
|
92286b2541 | ||
|
|
d8bcf745b0 | ||
|
|
8430139d80 | ||
|
|
a2962b4ce0 | ||
|
|
16fffdb75b | ||
|
|
036cecbf46 | ||
|
|
3482852bb6 | ||
|
|
fd62665b1f | ||
|
|
fc1da94520 | ||
|
|
1ffe48f0d4 | ||
|
|
a3b8a21385 | ||
|
|
86492b88c4 | ||
|
|
d08a629f9e | ||
|
|
36da464413 | ||
|
|
268e3404d3 | ||
|
|
54d0591833 | ||
|
|
ae6b61301c | ||
|
|
86370a0e7b | ||
|
|
a444e551b3 | ||
|
|
53b9a2002f | ||
|
|
cb16d0f45f | ||
|
|
e8d8bd8f18 | ||
|
|
8b07f21c28 | ||
|
|
54be772ffd | ||
|
|
4b76d93cec | ||
|
|
3c3a454e61 | ||
|
|
5ff77b3595 | ||
|
|
b180edbe5c | ||
|
|
de3b5c78d7 | ||
|
|
0b42f40cf6 | ||
|
|
062d1ec76f | ||
|
|
0a042ac36d | ||
|
|
c111675dd8 | ||
|
|
e7f921d787 | ||
|
|
e9f11fb11b | ||
|
|
419ed275fa | ||
|
|
2d4fcaf186 | ||
|
|
acf172b52c | ||
|
|
8c81a823fa | ||
|
|
619c549547 | ||
|
|
60ffe0dc87 | ||
|
|
9a713a0987 | ||
|
|
c4945cd565 | ||
|
|
1e10c17ecb | ||
|
|
bcc5824980 | ||
|
|
af5796de1c | ||
|
|
9d604b7e66 | ||
|
|
96d5190436 | ||
|
|
d19c26df06 | ||
|
|
36e36414d9 | ||
|
|
7e69589e05 | ||
|
|
aa613ab79a | ||
|
|
6ead0ff95e | ||
|
|
0db65a8984 | ||
|
|
c138807e95 | ||
|
|
637c0c8949 | ||
|
|
c72e13d8e6 | ||
|
|
f6d7bccfa0 | ||
|
|
e3ed01cafb | ||
|
|
fa748a7ec2 | ||
|
|
cccc615783 | ||
|
|
2021463ca0 | ||
|
|
f48cfd52e9 | ||
|
|
6838f53f40 | ||
|
|
8276236dfa | ||
|
|
994b923d56 | ||
|
|
59e2432231 | ||
|
|
eee0d123e4 | ||
|
|
e943203ae2 | ||
|
|
82c12cc8ae | ||
|
|
6a775217cf | ||
|
|
175674749f | ||
|
|
1e534cecf6 | ||
|
|
aa3aa8c6a8 | ||
|
|
fbdfe45c25 | ||
|
|
81ee172db8 | ||
|
|
f8fd65a65f | ||
|
|
62b978c050 | ||
|
|
266fdcd2ed | ||
|
|
0819df916e | ||
|
|
c8a558f797 | ||
|
|
dabdef4d67 | ||
|
|
cc48594b0b | ||
|
|
559e673107 | ||
|
|
b64bee35fa | ||
|
|
9a0354b681 | ||
|
|
73101c8977 | ||
|
|
73ce746ba7 | ||
|
|
a74208abac | ||
|
|
b307298b2f | ||
|
|
f00a997167 | ||
|
|
5134e3a06a | ||
|
|
6554026a82 | ||
|
|
a854660402 | ||
|
|
a0b48f971c | ||
|
|
96de928cb3 | ||
|
|
4ebf1410c6 | ||
|
|
77e40f41f2 | ||
|
|
d7d5b1b1d6 | ||
|
|
630edf2480 | ||
|
|
ea469d28d7 | ||
|
|
631ef4ed28 | ||
|
|
597f1d47b8 | ||
|
|
fcc96417f9 | ||
|
|
39986b0e97 | ||
|
|
8755211a60 | ||
|
|
62a0c358f9 | ||
|
|
87311074f1 | ||
|
|
33cf9535b3 | ||
|
|
7e6beee7f6 | ||
|
|
27b3891b14 | ||
|
|
2a864832c6 | ||
|
|
c974c12d65 | ||
|
|
50926bdbb4 | ||
|
|
bd381d59cd | ||
|
|
f67e56d3b9 | ||
|
|
8fb5a9ce11 | ||
|
|
4cdb2e533a | ||
|
|
abe8da697c | ||
|
|
039a985f41 | ||
|
|
c4a6dafd27 | ||
|
|
a930c2aecf | ||
|
|
e6d4653b08 | ||
|
|
d48edb9837 | ||
|
|
b41de7fcd1 | ||
|
|
18f84f0df5 | ||
|
|
44407a158a | ||
|
|
488b697479 | ||
|
|
5953b43ead | ||
|
|
58b2eb4b92 | ||
|
|
05415f72ec | ||
|
|
b7af53ea40 | ||
|
|
cee4aeea9e | ||
|
|
eb69f2de78 | ||
|
|
206420c085 | ||
|
|
88a864c195 | ||
|
|
ca9aca9b19 | ||
|
|
e00a280329 | ||
|
|
fe370e7d8f | ||
|
|
a789e9e6d8 | ||
|
|
9930913e4e | ||
|
|
125b5e2b16 | ||
|
|
48675f579f | ||
|
|
afec455f86 | ||
|
|
035c5d9f23 | ||
|
|
97d498c59c | ||
|
|
b2a5b29fb2 | ||
|
|
9ec61206c2 | ||
|
|
0125cd97d8 | ||
|
|
7d385b8dc3 | ||
|
|
f930ef2ee6 | ||
|
|
1b011a2d85 | ||
|
|
a85ea1ddb0 | ||
|
|
829e40d2aa | ||
|
|
6344e34880 | ||
|
|
a76ca8c565 | ||
|
|
771c99a523 | ||
|
|
26693e4ea8 | ||
|
|
e20be2397c | ||
|
|
46766e7e24 | ||
|
|
a7ddb8f1f8 | ||
|
|
7335c82553 | ||
|
|
a32ec97911 | ||
|
|
f6a71f4193 | ||
|
|
5c05131a94 | ||
|
|
b6abd4b4da | ||
|
|
2605948e01 | ||
|
|
eb2ac039c7 | ||
|
|
790a9ed7df | ||
|
|
2e61ce006d | ||
|
|
3cc485759e | ||
|
|
aafa9c67fc | ||
|
|
69f48db0a3 | ||
|
|
8c965434ae | ||
|
|
78da6b42ad | ||
|
|
1ad2cb5582 | ||
|
|
c619bf5b0c | ||
|
|
9f4db0a953 | ||
|
|
3e836db1d1 | ||
|
|
c01874e9ce | ||
|
|
1b2517ea20 | ||
|
|
3e9f0d57ac | ||
|
|
481bbe8513 | ||
|
|
bc7b2c6ba3 | ||
|
|
c6f7a299a9 | ||
|
|
992a6c79b4 | ||
|
|
78795a4a73 | ||
|
|
5a82477d48 | ||
|
|
1ffa519387 | ||
|
|
e4a25b6a60 | ||
|
|
6a6b527f24 | ||
|
|
b34887a920 | ||
|
|
b9efda3ce8 | ||
|
|
516de93627 | ||
|
|
15f0a665f8 | ||
|
|
9b5b632ff9 | ||
|
|
0c28099712 | ||
|
|
522dd44bfa | ||
|
|
8154069e77 | ||
|
|
e161a92898 | ||
|
|
3fce8485bb | ||
|
|
1cc88a2190 | ||
|
|
168ea9560e | ||
|
|
f48e33b395 | ||
|
|
f1ed8599fc | ||
|
|
93f3e1b14b | ||
|
|
649bfb236b | ||
|
|
409003b4f9 | ||
|
|
9e6e34b42d | ||
|
|
d9905d1a57 | ||
|
|
2bd68efc08 | ||
|
|
6848e1e128 | ||
|
|
668aead4c8 | ||
|
|
f08605a7f1 | ||
|
|
02a3feddb8 | ||
|
|
d9487a5749 | ||
|
|
cfa6d09c5e | ||
|
|
a01253c3c8 | ||
|
|
bc013e4888 | ||
|
|
782e3f8853 | ||
|
|
03fd656344 | ||
|
|
18b049cd24 | ||
|
|
2bdb4cb44a | ||
|
|
abbdf20f65 | ||
|
|
43ef64cf67 | ||
|
|
18316be09a | ||
|
|
1a623943c8 | ||
|
|
fbce8bb511 | ||
|
|
445b626dc8 | ||
|
|
b3c87cb5d1 | ||
|
|
0dbaddc7be | ||
|
|
ad9f044aad | ||
|
|
05930ee6b1 | ||
|
|
e670068cab | ||
|
|
b48cf1bf65 | ||
|
|
7ee7ada273 | ||
|
|
82b4e58ad0 | ||
|
|
ddc365f7a0 | ||
|
|
37ad370344 | ||
|
|
703647da1e | ||
|
|
9eff58ae62 | ||
|
|
3844516aa7 | ||
|
|
f591e47404 | ||
|
|
287ae81195 | ||
|
|
a4a30744ad | ||
|
|
dcba6a6b7e | ||
|
|
6142828a9c | ||
|
|
97bb74f824 | ||
|
|
2147bf75eb | ||
|
|
e40a29ba17 | ||
|
|
ff330e644e | ||
|
|
713e320c4c | ||
|
|
e67fe89adb | ||
|
|
6cfbb1f320 | ||
|
|
c853011a32 | ||
|
|
b50b89ba14 | ||
|
|
d063fbb8b9 | ||
|
|
e5d42bc963 | ||
|
|
8866394eb6 | ||
|
|
17c20b45ce | ||
|
|
7dacd9cb23 | ||
|
|
6285e0d23e | ||
|
|
a4826cfb5f | ||
|
|
a0bf0bdcc0 | ||
|
|
dffce78a8c | ||
|
|
c7e7ad5030 | ||
|
|
5142dc52c1 | ||
|
|
ecb44ff306 | ||
|
|
e4a5fb3e91 | ||
|
|
e52d352a48 | ||
|
|
f9723c9266 | ||
|
|
8efad1d170 | ||
|
|
c6641be94b | ||
|
|
89cf8a55e2 | ||
|
|
00c3b67182 | ||
|
|
9203690033 | ||
|
|
9683da54b0 | ||
|
|
0e48a772ff | ||
|
|
f118d81d32 | ||
|
|
ca12bc6953 | ||
|
|
9810386937 | ||
|
|
f1625b32bd | ||
|
|
0ecd5f2118 | ||
|
|
940d0c48c6 | ||
|
|
56cecf849e | ||
|
|
05c4aa7c2c | ||
|
|
2a5cb16494 |
@@ -1,4 +1,4 @@
|
|||||||
FROM golang:1.21-bullseye
|
FROM golang:1.23-bullseye
|
||||||
|
|
||||||
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
RUN apt-get update && export DEBIAN_FRONTEND=noninteractive \
|
||||||
&& apt-get -y install --no-install-recommends\
|
&& apt-get -y install --no-install-recommends\
|
||||||
|
|||||||
@@ -7,7 +7,7 @@
|
|||||||
"features": {
|
"features": {
|
||||||
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
"ghcr.io/devcontainers/features/docker-in-docker:2": {},
|
||||||
"ghcr.io/devcontainers/features/go:1": {
|
"ghcr.io/devcontainers/features/go:1": {
|
||||||
"version": "1.21"
|
"version": "1.23"
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
"workspaceFolder": "/workspaces/${localWorkspaceFolderBasename}",
|
||||||
|
|||||||
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
25
.github/ISSUE_TEMPLATE/bug-issue-report.md
vendored
@@ -31,14 +31,22 @@ Please specify whether you use NetBird Cloud or self-host NetBird's control plan
|
|||||||
|
|
||||||
`netbird version`
|
`netbird version`
|
||||||
|
|
||||||
**NetBird status -dA output:**
|
**Is any other VPN software installed?**
|
||||||
|
|
||||||
If applicable, add the `netbird status -dA' command output.
|
If yes, which one?
|
||||||
|
|
||||||
**Do you face any (non-mobile) client issues?**
|
**Debug output**
|
||||||
|
|
||||||
Please provide the file created by `netbird debug for 1m -AS`.
|
To help us resolve the problem, please attach the following debug output
|
||||||
We advise reviewing the anonymized files for any remaining PII.
|
|
||||||
|
netbird status -dA
|
||||||
|
|
||||||
|
As well as the file created by
|
||||||
|
|
||||||
|
netbird debug for 1m -AS
|
||||||
|
|
||||||
|
|
||||||
|
We advise reviewing the anonymized output for any remaining personal information.
|
||||||
|
|
||||||
**Screenshots**
|
**Screenshots**
|
||||||
|
|
||||||
@@ -47,3 +55,10 @@ If applicable, add screenshots to help explain your problem.
|
|||||||
**Additional context**
|
**Additional context**
|
||||||
|
|
||||||
Add any other context about the problem here.
|
Add any other context about the problem here.
|
||||||
|
|
||||||
|
**Have you tried these troubleshooting steps?**
|
||||||
|
- [ ] Checked for newer NetBird versions
|
||||||
|
- [ ] Searched for similar issues on GitHub (including closed ones)
|
||||||
|
- [ ] Restarted the NetBird client
|
||||||
|
- [ ] Disabled other VPN software
|
||||||
|
- [ ] Checked firewall settings
|
||||||
|
|||||||
13
.github/workflows/golang-test-darwin.yml
vendored
13
.github/workflows/golang-test-darwin.yml
vendored
@@ -1,4 +1,4 @@
|
|||||||
name: Test Code Darwin
|
name: "Darwin"
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -12,15 +12,14 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
strategy:
|
name: "Client / Unit"
|
||||||
matrix:
|
|
||||||
store: ['sqlite']
|
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
@@ -28,8 +27,9 @@ jobs:
|
|||||||
uses: actions/cache@v4
|
uses: actions/cache@v4
|
||||||
with:
|
with:
|
||||||
path: ~/go/pkg/mod
|
path: ~/go/pkg/mod
|
||||||
key: macos-go-${{ hashFiles('**/go.sum') }}
|
key: macos-gotest-${{ hashFiles('**/go.sum') }}
|
||||||
restore-keys: |
|
restore-keys: |
|
||||||
|
macos-gotest-
|
||||||
macos-go-
|
macos-go-
|
||||||
|
|
||||||
- name: Install libpcap
|
- name: Install libpcap
|
||||||
@@ -42,4 +42,5 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 ./...
|
run: NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -tags=devcert -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 5m -p 1 $(go list ./... | grep -v /management)
|
||||||
|
|
||||||
|
|||||||
8
.github/workflows/golang-test-freebsd.yml
vendored
8
.github/workflows/golang-test-freebsd.yml
vendored
@@ -1,5 +1,4 @@
|
|||||||
|
name: "FreeBSD"
|
||||||
name: Test Code FreeBSD
|
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -13,6 +12,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
name: "Client / Unit"
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- uses: actions/checkout@v4
|
- uses: actions/checkout@v4
|
||||||
@@ -24,7 +24,7 @@ jobs:
|
|||||||
copyback: false
|
copyback: false
|
||||||
release: "14.1"
|
release: "14.1"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y go
|
pkg install -y go pkgconf xorg
|
||||||
|
|
||||||
# -x - to print all executed commands
|
# -x - to print all executed commands
|
||||||
# -e - to faile on first error
|
# -e - to faile on first error
|
||||||
@@ -33,7 +33,7 @@ jobs:
|
|||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -timeout 1m -failfast ./base62/...
|
||||||
# NOTE: without -p1 `client/internal/dns` will fail becasue of `listen udp4 :33100: bind: address already in use`
|
# NOTE: without -p1 `client/internal/dns` will fail because of `listen udp4 :33100: bind: address already in use`
|
||||||
time go test -timeout 8m -failfast -p 1 ./client/...
|
time go test -timeout 8m -failfast -p 1 ./client/...
|
||||||
time go test -timeout 1m -failfast ./dns/...
|
time go test -timeout 1m -failfast ./dns/...
|
||||||
time go test -timeout 1m -failfast ./encryption/...
|
time go test -timeout 1m -failfast ./encryption/...
|
||||||
|
|||||||
508
.github/workflows/golang-test-linux.yml
vendored
508
.github/workflows/golang-test-linux.yml
vendored
@@ -1,4 +1,4 @@
|
|||||||
name: Test Code Linux
|
name: Linux
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -11,31 +11,125 @@ concurrency:
|
|||||||
cancel-in-progress: true
|
cancel-in-progress: true
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
|
build-cache:
|
||||||
|
name: "Build Cache"
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
outputs:
|
||||||
|
management: ${{ steps.filter.outputs.management }}
|
||||||
|
steps:
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- uses: dorny/paths-filter@v3
|
||||||
|
id: filter
|
||||||
|
with:
|
||||||
|
filters: |
|
||||||
|
management:
|
||||||
|
- 'management/**'
|
||||||
|
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
id: cache
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Build client
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: client
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build client 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: client
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o client-386 .
|
||||||
|
|
||||||
|
- name: Build management
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: management
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build management 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: management
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o management-386 .
|
||||||
|
|
||||||
|
- name: Build signal
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: signal
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build signal 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: signal
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o signal-386 .
|
||||||
|
|
||||||
|
- name: Build relay
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: relay
|
||||||
|
run: CGO_ENABLED=1 go build .
|
||||||
|
|
||||||
|
- name: Build relay 386
|
||||||
|
if: steps.cache.outputs.cache-hit != 'true'
|
||||||
|
working-directory: relay
|
||||||
|
run: CGO_ENABLED=1 GOARCH=386 go build -o relay-386 .
|
||||||
|
|
||||||
test:
|
test:
|
||||||
|
name: "Client / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
strategy:
|
strategy:
|
||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
arch: [ '386','amd64' ]
|
arch: [ '386','amd64' ]
|
||||||
store: [ 'sqlite', 'postgres']
|
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ~/go/pkg/mod
|
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-go-
|
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
@@ -50,27 +144,401 @@ jobs:
|
|||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true go test -exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' -timeout 10m -p 1 ./...
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||||
|
|
||||||
|
test_relay:
|
||||||
|
name: "Relay / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test \
|
||||||
|
-exec 'sudo' \
|
||||||
|
-timeout 10m ./signal/...
|
||||||
|
|
||||||
|
test_signal:
|
||||||
|
name: "Signal / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ '386','amd64' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
go test \
|
||||||
|
-exec 'sudo' \
|
||||||
|
-timeout 10m ./signal/...
|
||||||
|
|
||||||
|
test_management:
|
||||||
|
name: "Management / Unit"
|
||||||
|
needs: [ build-cache ]
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ 'amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Login to Docker hub
|
||||||
|
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||||
|
uses: docker/login-action@v1
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: download mysql image
|
||||||
|
if: matrix.store == 'mysql'
|
||||||
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
go test -tags=devcert \
|
||||||
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
|
benchmark:
|
||||||
|
name: "Management / Benchmark"
|
||||||
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ 'amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Login to Docker hub
|
||||||
|
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||||
|
uses: docker/login-action@v1
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: download mysql image
|
||||||
|
if: matrix.store == 'mysql'
|
||||||
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 20m ./...
|
||||||
|
|
||||||
|
api_benchmark:
|
||||||
|
name: "Management / Benchmark (API)"
|
||||||
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ 'amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres' ]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Login to Docker hub
|
||||||
|
if: matrix.store == 'mysql' && (github.repository == github.head.repo.full_name || !github.head_ref)
|
||||||
|
uses: docker/login-action@v1
|
||||||
|
with:
|
||||||
|
username: ${{ secrets.DOCKER_USER }}
|
||||||
|
password: ${{ secrets.DOCKER_TOKEN }}
|
||||||
|
|
||||||
|
- name: download mysql image
|
||||||
|
if: matrix.store == 'mysql'
|
||||||
|
run: docker pull mlsmaycon/warmed-mysql:8
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags=benchmark \
|
||||||
|
-run=^$ \
|
||||||
|
-bench=. \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
|
api_integration_test:
|
||||||
|
name: "Management / Integration"
|
||||||
|
needs: [ build-cache ]
|
||||||
|
if: ${{ needs.build-cache.outputs.management == 'true' || github.event_name != 'pull_request' }}
|
||||||
|
strategy:
|
||||||
|
fail-fast: false
|
||||||
|
matrix:
|
||||||
|
arch: [ 'amd64' ]
|
||||||
|
store: [ 'sqlite', 'postgres']
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Install dependencies
|
||||||
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
- name: Install 32-bit libpcap
|
||||||
|
if: matrix.arch == '386'
|
||||||
|
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
||||||
|
|
||||||
|
- name: Install modules
|
||||||
|
run: go mod tidy
|
||||||
|
|
||||||
|
- name: check git status
|
||||||
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
|
- name: Test
|
||||||
|
run: |
|
||||||
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
||||||
|
go test -tags=integration \
|
||||||
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
|
name: "Client (Docker) / Unit"
|
||||||
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-20.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache@v4
|
|
||||||
with:
|
|
||||||
path: ~/go/pkg/mod
|
|
||||||
key: ${{ runner.os }}-go-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-go-
|
|
||||||
|
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
uses: actions/checkout@v4
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
- name: Install dependencies
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
||||||
|
|
||||||
|
|||||||
28
.github/workflows/golang-test-windows.yml
vendored
28
.github/workflows/golang-test-windows.yml
vendored
@@ -1,4 +1,4 @@
|
|||||||
name: Test Code Windows
|
name: "Windows"
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -14,6 +14,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test:
|
test:
|
||||||
|
name: "Client / Unit"
|
||||||
runs-on: windows-latest
|
runs-on: windows-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout code
|
- name: Checkout code
|
||||||
@@ -24,6 +25,23 @@ jobs:
|
|||||||
id: go
|
id: go
|
||||||
with:
|
with:
|
||||||
go-version: "1.23.x"
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
run: |
|
||||||
|
echo "cache=$(go env GOCACHE)" >> $env:GITHUB_ENV
|
||||||
|
echo "modcache=$(go env GOMODCACHE)" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache@v4
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ env.cache }}
|
||||||
|
${{ env.modcache }}
|
||||||
|
key: ${{ runner.os }}-gotest-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-
|
||||||
|
${{ runner.os }}-go-
|
||||||
|
|
||||||
- name: Download wintun
|
- name: Download wintun
|
||||||
uses: carlosperate/download-file-action@v2
|
uses: carlosperate/download-file-action@v2
|
||||||
@@ -42,11 +60,13 @@ jobs:
|
|||||||
- run: choco install -y sysinternals --ignore-checksums
|
- run: choco install -y sysinternals --ignore-checksums
|
||||||
- run: choco install -y mingw
|
- run: choco install -y mingw
|
||||||
|
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=C:\Users\runneradmin\go\pkg\mod
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOMODCACHE=${{ env.cache }}
|
||||||
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=C:\Users\runneradmin\AppData\Local\go-build
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe env -w GOCACHE=${{ env.modcache }}
|
||||||
|
- run: PsExec64 -s -w ${{ github.workspace }} C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe mod tidy
|
||||||
|
- run: echo "files=$(go list ./... | ForEach-Object { $_ } | Where-Object { $_ -notmatch '/management' })" >> $env:GITHUB_ENV
|
||||||
|
|
||||||
- name: test
|
- name: test
|
||||||
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -timeout 10m -p 1 ./... > test-out.txt 2>&1"
|
run: PsExec64 -s -w ${{ github.workspace }} cmd.exe /c "C:\hostedtoolcache\windows\go\${{ steps.go.outputs.go-version }}\x64\bin\go.exe test -tags=devcert -timeout 10m -p 1 ${{ env.files }} > test-out.txt 2>&1"
|
||||||
- name: test output
|
- name: test output
|
||||||
if: ${{ always() }}
|
if: ${{ always() }}
|
||||||
run: Get-Content test-out.txt
|
run: Get-Content test-out.txt
|
||||||
|
|||||||
17
.github/workflows/golangci-lint.yml
vendored
17
.github/workflows/golangci-lint.yml
vendored
@@ -1,4 +1,4 @@
|
|||||||
name: golangci-lint
|
name: Lint
|
||||||
on: [pull_request]
|
on: [pull_request]
|
||||||
|
|
||||||
permissions:
|
permissions:
|
||||||
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
@@ -27,7 +27,14 @@ jobs:
|
|||||||
fail-fast: false
|
fail-fast: false
|
||||||
matrix:
|
matrix:
|
||||||
os: [macos-latest, windows-latest, ubuntu-latest]
|
os: [macos-latest, windows-latest, ubuntu-latest]
|
||||||
name: lint
|
include:
|
||||||
|
- os: macos-latest
|
||||||
|
display_name: Darwin
|
||||||
|
- os: windows-latest
|
||||||
|
display_name: Windows
|
||||||
|
- os: ubuntu-latest
|
||||||
|
display_name: Linux
|
||||||
|
name: ${{ matrix.display_name }}
|
||||||
runs-on: ${{ matrix.os }}
|
runs-on: ${{ matrix.os }}
|
||||||
timeout-minutes: 15
|
timeout-minutes: 15
|
||||||
steps:
|
steps:
|
||||||
@@ -46,7 +53,7 @@ jobs:
|
|||||||
if: matrix.os == 'ubuntu-latest'
|
if: matrix.os == 'ubuntu-latest'
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev libpcap-dev
|
||||||
- name: golangci-lint
|
- name: golangci-lint
|
||||||
uses: golangci/golangci-lint-action@v3
|
uses: golangci/golangci-lint-action@v4
|
||||||
with:
|
with:
|
||||||
version: latest
|
version: latest
|
||||||
args: --timeout=12m
|
args: --timeout=12m --out-format colored-line-number
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
name: Mobile build validation
|
name: Mobile
|
||||||
|
|
||||||
on:
|
on:
|
||||||
push:
|
push:
|
||||||
@@ -12,6 +12,7 @@ concurrency:
|
|||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
android_build:
|
android_build:
|
||||||
|
name: "Android / Build"
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
@@ -47,6 +48,7 @@ jobs:
|
|||||||
CGO_ENABLED: 0
|
CGO_ENABLED: 0
|
||||||
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
ANDROID_NDK_HOME: /usr/local/lib/android/sdk/ndk/23.1.7779620
|
||||||
ios_build:
|
ios_build:
|
||||||
|
name: "iOS / Build"
|
||||||
runs-on: macos-latest
|
runs-on: macos-latest
|
||||||
steps:
|
steps:
|
||||||
- name: Checkout repository
|
- name: Checkout repository
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -9,10 +9,10 @@ on:
|
|||||||
pull_request:
|
pull_request:
|
||||||
|
|
||||||
env:
|
env:
|
||||||
SIGN_PIPE_VER: "v0.0.17"
|
SIGN_PIPE_VER: "v0.0.18"
|
||||||
GORELEASER_VER: "v2.3.2"
|
GORELEASER_VER: "v2.3.2"
|
||||||
PRODUCT_NAME: "NetBird"
|
PRODUCT_NAME: "NetBird"
|
||||||
COPYRIGHT: "Wiretrustee UG (haftungsbeschreankt)"
|
COPYRIGHT: "NetBird GmbH"
|
||||||
|
|
||||||
concurrency:
|
concurrency:
|
||||||
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
group: ${{ github.workflow }}-${{ github.ref }}-${{ github.head_ref || github.actor_id }}
|
||||||
@@ -71,7 +71,7 @@ jobs:
|
|||||||
- name: Install goversioninfo
|
- name: Install goversioninfo
|
||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
run: goversioninfo -icon client/ui/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
run: goversioninfo -icon client/ui/assets/netbird.ico -manifest client/manifest.xml -product-name ${{ env.PRODUCT_NAME }} -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/resources_windows_amd64.syso
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
with:
|
with:
|
||||||
@@ -150,7 +150,7 @@ jobs:
|
|||||||
- name: Install goversioninfo
|
- name: Install goversioninfo
|
||||||
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
run: go install github.com/josephspurrier/goversioninfo/cmd/goversioninfo@233067e
|
||||||
- name: Generate windows syso amd64
|
- name: Generate windows syso amd64
|
||||||
run: goversioninfo -64 -icon client/ui/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
run: goversioninfo -64 -icon client/ui/assets/netbird.ico -manifest client/ui/manifest.xml -product-name ${{ env.PRODUCT_NAME }}-"UI" -copyright "${{ env.COPYRIGHT }}" -ver-major ${{ steps.semver_parser.outputs.major }} -ver-minor ${{ steps.semver_parser.outputs.minor }} -ver-patch ${{ steps.semver_parser.outputs.patch }} -ver-build 0 -file-version ${{ steps.semver_parser.outputs.fullversion }}.0 -product-version ${{ steps.semver_parser.outputs.fullversion }}.0 -o client/ui/resources_windows_amd64.syso
|
||||||
|
|
||||||
- name: Run GoReleaser
|
- name: Run GoReleaser
|
||||||
uses: goreleaser/goreleaser-action@v4
|
uses: goreleaser/goreleaser-action@v4
|
||||||
|
|||||||
23
.github/workflows/test-infrastructure-files.yml
vendored
23
.github/workflows/test-infrastructure-files.yml
vendored
@@ -20,7 +20,7 @@ jobs:
|
|||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
strategy:
|
strategy:
|
||||||
matrix:
|
matrix:
|
||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres', 'mysql' ]
|
||||||
services:
|
services:
|
||||||
postgres:
|
postgres:
|
||||||
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
image: ${{ (matrix.store == 'postgres') && 'postgres' || '' }}
|
||||||
@@ -34,6 +34,19 @@ jobs:
|
|||||||
--health-timeout 5s
|
--health-timeout 5s
|
||||||
ports:
|
ports:
|
||||||
- 5432:5432
|
- 5432:5432
|
||||||
|
mysql:
|
||||||
|
image: ${{ (matrix.store == 'mysql') && 'mysql' || '' }}
|
||||||
|
env:
|
||||||
|
MYSQL_USER: netbird
|
||||||
|
MYSQL_PASSWORD: mysql
|
||||||
|
MYSQL_ROOT_PASSWORD: mysqlroot
|
||||||
|
MYSQL_DATABASE: netbird
|
||||||
|
options: >-
|
||||||
|
--health-cmd "mysqladmin ping --silent"
|
||||||
|
--health-interval 10s
|
||||||
|
--health-timeout 5s
|
||||||
|
ports:
|
||||||
|
- 3306:3306
|
||||||
steps:
|
steps:
|
||||||
- name: Set Database Connection String
|
- name: Set Database Connection String
|
||||||
run: |
|
run: |
|
||||||
@@ -42,6 +55,11 @@ jobs:
|
|||||||
else
|
else
|
||||||
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
|
echo "NETBIRD_STORE_ENGINE_POSTGRES_DSN==" >> $GITHUB_ENV
|
||||||
fi
|
fi
|
||||||
|
if [ "${{ matrix.store }}" == "mysql" ]; then
|
||||||
|
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN=netbird:mysql@tcp($(hostname -I | awk '{print $1}'):3306)/netbird" >> $GITHUB_ENV
|
||||||
|
else
|
||||||
|
echo "NETBIRD_STORE_ENGINE_MYSQL_DSN==" >> $GITHUB_ENV
|
||||||
|
fi
|
||||||
|
|
||||||
- name: Install jq
|
- name: Install jq
|
||||||
run: sudo apt-get install -y jq
|
run: sudo apt-get install -y jq
|
||||||
@@ -84,6 +102,7 @@ jobs:
|
|||||||
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
CI_NETBIRD_AUTH_SUPPORTED_SCOPES: "openid profile email offline_access api email_verified"
|
||||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
|
NETBIRD_STORE_ENGINE_POSTGRES_DSN: ${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}
|
||||||
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: ${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
|
|
||||||
- name: check values
|
- name: check values
|
||||||
@@ -112,6 +131,7 @@ jobs:
|
|||||||
CI_NETBIRD_SIGNAL_PORT: 12345
|
CI_NETBIRD_SIGNAL_PORT: 12345
|
||||||
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
CI_NETBIRD_STORE_CONFIG_ENGINE: ${{ matrix.store }}
|
||||||
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
NETBIRD_STORE_ENGINE_POSTGRES_DSN: '${{ env.NETBIRD_STORE_ENGINE_POSTGRES_DSN }}$'
|
||||||
|
NETBIRD_STORE_ENGINE_MYSQL_DSN: '${{ env.NETBIRD_STORE_ENGINE_MYSQL_DSN }}$'
|
||||||
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
CI_NETBIRD_MGMT_IDP_SIGNKEY_REFRESH: false
|
||||||
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
CI_NETBIRD_TURN_EXTERNAL_IP: "1.2.3.4"
|
||||||
|
|
||||||
@@ -149,6 +169,7 @@ jobs:
|
|||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep Scope | grep "$CI_NETBIRD_AUTH_SUPPORTED_SCOPES"
|
||||||
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
grep -A 10 PKCEAuthorizationFlow management.json | grep -A 10 ProviderConfig | grep -A 3 RedirectURLs | grep "http://localhost:53000"
|
||||||
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
grep "external-ip" turnserver.conf | grep $CI_NETBIRD_TURN_EXTERNAL_IP
|
||||||
|
grep "NETBIRD_STORE_ENGINE_MYSQL_DSN=$NETBIRD_STORE_ENGINE_MYSQL_DSN" docker-compose.yml
|
||||||
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
grep NETBIRD_STORE_ENGINE_POSTGRES_DSN docker-compose.yml | egrep "$NETBIRD_STORE_ENGINE_POSTGRES_DSN"
|
||||||
# check relay values
|
# check relay values
|
||||||
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
grep "NB_EXPOSED_ADDRESS=$CI_NETBIRD_DOMAIN:33445" docker-compose.yml
|
||||||
|
|||||||
1
.gitignore
vendored
1
.gitignore
vendored
@@ -29,3 +29,4 @@ infrastructure_files/setup.env
|
|||||||
infrastructure_files/setup-*.env
|
infrastructure_files/setup-*.env
|
||||||
.vscode
|
.vscode
|
||||||
.DS_Store
|
.DS_Store
|
||||||
|
vendor/
|
||||||
|
|||||||
@@ -103,7 +103,7 @@ linters:
|
|||||||
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
- predeclared # predeclared finds code that shadows one of Go's predeclared identifiers
|
||||||
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
- revive # Fast, configurable, extensible, flexible, and beautiful linter for Go. Drop-in replacement of golint.
|
||||||
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
- sqlclosecheck # checks that sql.Rows and sql.Stmt are closed
|
||||||
- thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
# - thelper # thelper detects Go test helpers without t.Helper() call and checks the consistency of test helpers.
|
||||||
- wastedassign # wastedassign finds wasted assignment statements
|
- wastedassign # wastedassign finds wasted assignment statements
|
||||||
issues:
|
issues:
|
||||||
# Maximum count of issues with the same text.
|
# Maximum count of issues with the same text.
|
||||||
|
|||||||
@@ -179,6 +179,51 @@ dockers:
|
|||||||
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
- "--label=org.opencontainers.image.version={{.Version}}"
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
- "--label=maintainer=dev@netbird.io"
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
ids:
|
||||||
|
- netbird
|
||||||
|
goarch: amd64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: client/Dockerfile-rootless
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/amd64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
ids:
|
||||||
|
- netbird
|
||||||
|
goarch: arm64
|
||||||
|
use: buildx
|
||||||
|
dockerfile: client/Dockerfile-rootless
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm64"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
- image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
ids:
|
||||||
|
- netbird
|
||||||
|
goarch: arm
|
||||||
|
goarm: 6
|
||||||
|
use: buildx
|
||||||
|
dockerfile: client/Dockerfile-rootless
|
||||||
|
build_flag_templates:
|
||||||
|
- "--platform=linux/arm"
|
||||||
|
- "--label=org.opencontainers.image.created={{.Date}}"
|
||||||
|
- "--label=org.opencontainers.image.title={{.ProjectName}}"
|
||||||
|
- "--label=org.opencontainers.image.version={{.Version}}"
|
||||||
|
- "--label=org.opencontainers.image.revision={{.FullCommit}}"
|
||||||
|
- "--label=maintainer=dev@netbird.io"
|
||||||
|
|
||||||
- image_templates:
|
- image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-amd64
|
- netbirdio/relay:{{ .Version }}-amd64
|
||||||
ids:
|
ids:
|
||||||
@@ -377,6 +422,18 @@ docker_manifests:
|
|||||||
- netbirdio/netbird:{{ .Version }}-arm
|
- netbirdio/netbird:{{ .Version }}-arm
|
||||||
- netbirdio/netbird:{{ .Version }}-amd64
|
- netbirdio/netbird:{{ .Version }}-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird:{{ .Version }}-rootless
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
|
- name_template: netbirdio/netbird:rootless-latest
|
||||||
|
image_templates:
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm64v8
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-arm
|
||||||
|
- netbirdio/netbird:{{ .Version }}-rootless-amd64
|
||||||
|
|
||||||
- name_template: netbirdio/relay:{{ .Version }}
|
- name_template: netbirdio/relay:{{ .Version }}
|
||||||
image_templates:
|
image_templates:
|
||||||
- netbirdio/relay:{{ .Version }}-arm64v8
|
- netbirdio/relay:{{ .Version }}-arm64v8
|
||||||
|
|||||||
@@ -50,10 +50,12 @@ nfpms:
|
|||||||
- netbird-ui
|
- netbird-ui
|
||||||
formats:
|
formats:
|
||||||
- deb
|
- deb
|
||||||
|
scripts:
|
||||||
|
postinstall: "release_files/ui-post-install.sh"
|
||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/build/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-connected.png
|
- src: client/ui/assets/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
@@ -67,10 +69,12 @@ nfpms:
|
|||||||
- netbird-ui
|
- netbird-ui
|
||||||
formats:
|
formats:
|
||||||
- rpm
|
- rpm
|
||||||
|
scripts:
|
||||||
|
postinstall: "release_files/ui-post-install.sh"
|
||||||
contents:
|
contents:
|
||||||
- src: client/ui/netbird.desktop
|
- src: client/ui/build/netbird.desktop
|
||||||
dst: /usr/share/applications/netbird.desktop
|
dst: /usr/share/applications/netbird.desktop
|
||||||
- src: client/ui/netbird-systemtray-connected.png
|
- src: client/ui/assets/netbird.png
|
||||||
dst: /usr/share/pixmaps/netbird.png
|
dst: /usr/share/pixmaps/netbird.png
|
||||||
dependencies:
|
dependencies:
|
||||||
- netbird
|
- netbird
|
||||||
|
|||||||
2
AUTHORS
2
AUTHORS
@@ -1,3 +1,3 @@
|
|||||||
Mikhail Bragin (https://github.com/braginini)
|
Mikhail Bragin (https://github.com/braginini)
|
||||||
Maycon Santos (https://github.com/mlsmaycon)
|
Maycon Santos (https://github.com/mlsmaycon)
|
||||||
Wiretrustee UG (haftungsbeschränkt)
|
NetBird GmbH
|
||||||
|
|||||||
@@ -3,10 +3,10 @@
|
|||||||
We are incredibly thankful for the contributions we receive from the community.
|
We are incredibly thankful for the contributions we receive from the community.
|
||||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
||||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
||||||
as BSD-3 while allowing Wiretrustee to build a sustainable business.
|
as BSD-3 while allowing NetBird to build a sustainable business.
|
||||||
|
|
||||||
Wiretrustee is committed to having a true Open Source Software ("OSS") license for
|
NetBird is committed to having a true Open Source Software ("OSS") license for
|
||||||
our software. A CLA enables Wiretrustee to safely commercialize our products
|
our software. A CLA enables NetBird to safely commercialize our products
|
||||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
while keeping a standard OSS license with all the rights that license grants to users: the
|
||||||
ability to use the project in their own projects or businesses, to republish modified
|
ability to use the project in their own projects or businesses, to republish modified
|
||||||
source, or to completely fork the project.
|
source, or to completely fork the project.
|
||||||
@@ -20,11 +20,11 @@ This is a human-readable summary of (and not a substitute for) the full agreemen
|
|||||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
This highlights only some of key terms of the CLA. It has no legal value and you should
|
||||||
carefully review all the terms of the actual CLA before agreeing.
|
carefully review all the terms of the actual CLA before agreeing.
|
||||||
|
|
||||||
<li>Grant of copyright license. You give Wiretrustee permission to use your copyrighted work
|
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
||||||
in commercial products.
|
in commercial products.
|
||||||
</li>
|
</li>
|
||||||
|
|
||||||
<li>Grant of patent license. If your contributed work uses a patent, you give Wiretrustee a
|
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
||||||
license to use that patent including within commercial products. You also agree that you
|
license to use that patent including within commercial products. You also agree that you
|
||||||
have permission to grant this license.
|
have permission to grant this license.
|
||||||
</li>
|
</li>
|
||||||
@@ -45,7 +45,7 @@ more.
|
|||||||
# Why require a CLA?
|
# Why require a CLA?
|
||||||
|
|
||||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
||||||
to use your contribution at a later date, and that Wiretrustee has permission to use your contribution in our commercial
|
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
||||||
products.
|
products.
|
||||||
|
|
||||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
||||||
@@ -65,25 +65,25 @@ Follow the steps given by the bot to sign the CLA. This will require you to log
|
|||||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
||||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
||||||
|
|
||||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any Wiretrustee project will not
|
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
||||||
require you to sign again.
|
require you to sign again.
|
||||||
|
|
||||||
# Legal Terms and Agreement
|
# Legal Terms and Agreement
|
||||||
|
|
||||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, Wiretrustee
|
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
||||||
UG (haftungsbeschränkt) ("Wiretrustee") must have a Contributor License Agreement ("CLA") on file that has been signed
|
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
||||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
||||||
your own Contributions for any other purpose.
|
your own Contributions for any other purpose.
|
||||||
|
|
||||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
||||||
Wiretrustee. Except for the license granted herein to Wiretrustee and recipients of software distributed by Wiretrustee,
|
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
||||||
You reserve all right, title, and interest in and to Your Contributions.
|
You reserve all right, title, and interest in and to Your Contributions.
|
||||||
|
|
||||||
1. Definitions.
|
1. Definitions.
|
||||||
|
|
||||||
```
|
```
|
||||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
||||||
that is making this Agreement with Wiretrustee. For legal entities, the entity making a Contribution and all other
|
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
||||||
entities that control, are controlled by, or are under common control with that entity are considered
|
entities that control, are controlled by, or are under common control with that entity are considered
|
||||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
||||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
||||||
@@ -91,23 +91,23 @@ You reserve all right, title, and interest in and to Your Contributions.
|
|||||||
```
|
```
|
||||||
```
|
```
|
||||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
||||||
an existing work, that is or previously has been intentionally submitted by You to Wiretrustee for inclusion in,
|
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
||||||
or documentation of, any of the products owned or managed by Wiretrustee (the "Work").
|
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
||||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
||||||
sent to Wiretrustee or its representatives, including but not limited to communication on electronic mailing lists,
|
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
||||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
||||||
Wiretrustee for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
||||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
marked or otherwise designated in writing by You as "Not a Contribution."
|
||||||
```
|
```
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee
|
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
||||||
and to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge,
|
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
||||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
||||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
perform, sublicense, and distribute Your Contributions and such derivative works.
|
||||||
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to Wiretrustee and
|
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
||||||
to recipients of software distributed by Wiretrustee a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
||||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
||||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
||||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
||||||
@@ -121,8 +121,8 @@ You reserve all right, title, and interest in and to Your Contributions.
|
|||||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
intellectual property that you create that includes your Contributions, you represent that you have received
|
||||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
||||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
||||||
your current and future Contributions to Wiretrustee, or that your employer has executed a separate Corporate CLA
|
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
||||||
with Wiretrustee.
|
with NetBird.
|
||||||
|
|
||||||
|
|
||||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
||||||
@@ -138,11 +138,11 @@ You reserve all right, title, and interest in and to Your Contributions.
|
|||||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
||||||
|
|
||||||
|
|
||||||
7. Should You wish to submit work that is not Your original creation, You may submit it to Wiretrustee separately from
|
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
||||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
||||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
||||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||||
|
|
||||||
|
|
||||||
8. You agree to notify Wiretrustee of any facts or circumstances of which you become aware that would make these
|
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
||||||
representations inaccurate in any respect.
|
representations inaccurate in any respect.
|
||||||
2
LICENSE
2
LICENSE
@@ -1,6 +1,6 @@
|
|||||||
BSD 3-Clause License
|
BSD 3-Clause License
|
||||||
|
|
||||||
Copyright (c) 2022 Wiretrustee UG (haftungsbeschränkt) & AUTHORS
|
Copyright (c) 2022 NetBird GmbH & AUTHORS
|
||||||
|
|
||||||
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
|
||||||
|
|
||||||
|
|||||||
17
README.md
17
README.md
@@ -1,11 +1,6 @@
|
|||||||
<p align="center">
|
|
||||||
<strong>:hatching_chick: New Release! Device Posture Checks.</strong>
|
|
||||||
<a href="https://docs.netbird.io/how-to/manage-posture-checks">
|
|
||||||
Learn more
|
|
||||||
</a>
|
|
||||||
</p>
|
|
||||||
<br/>
|
|
||||||
<div align="center">
|
<div align="center">
|
||||||
|
<br/>
|
||||||
|
<br/>
|
||||||
<p align="center">
|
<p align="center">
|
||||||
<img width="234" src="docs/media/logo-full.png"/>
|
<img width="234" src="docs/media/logo-full.png"/>
|
||||||
</p>
|
</p>
|
||||||
@@ -17,7 +12,7 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
@@ -34,10 +29,14 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2p5zwhm4g-8fHollzrQa5y4PZF5AEpvQ">Slack channel</a>
|
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
|
<br>
|
||||||
|
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
||||||
|
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
|
||||||
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
<br>
|
<br>
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
FROM alpine:3.20
|
FROM alpine:3.21.3
|
||||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
RUN apk add --no-cache ca-certificates iptables ip6tables
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||||
|
|||||||
17
client/Dockerfile-rootless
Normal file
17
client/Dockerfile-rootless
Normal file
@@ -0,0 +1,17 @@
|
|||||||
|
FROM alpine:3.21.0
|
||||||
|
|
||||||
|
COPY netbird /usr/local/bin/netbird
|
||||||
|
|
||||||
|
RUN apk add --no-cache ca-certificates \
|
||||||
|
&& adduser -D -h /var/lib/netbird netbird
|
||||||
|
WORKDIR /var/lib/netbird
|
||||||
|
USER netbird:netbird
|
||||||
|
|
||||||
|
ENV NB_FOREGROUND_MODE=true
|
||||||
|
ENV NB_USE_NETSTACK_MODE=true
|
||||||
|
ENV NB_ENABLE_NETSTACK_LOCAL_FORWARDING=true
|
||||||
|
ENV NB_CONFIG=config.json
|
||||||
|
ENV NB_DAEMON_ADDR=unix://netbird.sock
|
||||||
|
ENV NB_DISABLE_DNS=true
|
||||||
|
|
||||||
|
ENTRYPOINT [ "/usr/local/bin/netbird", "up" ]
|
||||||
@@ -162,7 +162,7 @@ func (a *Auth) login(urlOpener URLOpener) error {
|
|||||||
|
|
||||||
// check if we need to generate JWT token
|
// check if we need to generate JWT token
|
||||||
err := a.withBackOff(a.ctx, func() (err error) {
|
err := a.withBackOff(a.ctx, func() (err error) {
|
||||||
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config.PrivateKey, a.config.ManagementURL, a.config.SSHKey)
|
needsLogin, err = internal.IsLoginRequired(a.ctx, a.config)
|
||||||
return
|
return
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -12,6 +12,8 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const anonTLD = ".domain"
|
||||||
|
|
||||||
type Anonymizer struct {
|
type Anonymizer struct {
|
||||||
ipAnonymizer map[netip.Addr]netip.Addr
|
ipAnonymizer map[netip.Addr]netip.Addr
|
||||||
domainAnonymizer map[string]string
|
domainAnonymizer map[string]string
|
||||||
@@ -19,6 +21,8 @@ type Anonymizer struct {
|
|||||||
currentAnonIPv6 netip.Addr
|
currentAnonIPv6 netip.Addr
|
||||||
startAnonIPv4 netip.Addr
|
startAnonIPv4 netip.Addr
|
||||||
startAnonIPv6 netip.Addr
|
startAnonIPv6 netip.Addr
|
||||||
|
|
||||||
|
domainKeyRegex *regexp.Regexp
|
||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
@@ -34,6 +38,8 @@ func NewAnonymizer(startIPv4, startIPv6 netip.Addr) *Anonymizer {
|
|||||||
currentAnonIPv6: startIPv6,
|
currentAnonIPv6: startIPv6,
|
||||||
startAnonIPv4: startIPv4,
|
startAnonIPv4: startIPv4,
|
||||||
startAnonIPv6: startIPv6,
|
startAnonIPv6: startIPv6,
|
||||||
|
|
||||||
|
domainKeyRegex: regexp.MustCompile(`\bdomain=([^\s,:"]+)`),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -83,29 +89,39 @@ func (a *Anonymizer) AnonymizeIPString(ip string) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (a *Anonymizer) AnonymizeDomain(domain string) string {
|
func (a *Anonymizer) AnonymizeDomain(domain string) string {
|
||||||
if strings.HasSuffix(domain, "netbird.io") ||
|
baseDomain := domain
|
||||||
strings.HasSuffix(domain, "netbird.selfhosted") ||
|
hasDot := strings.HasSuffix(domain, ".")
|
||||||
strings.HasSuffix(domain, "netbird.cloud") ||
|
if hasDot {
|
||||||
strings.HasSuffix(domain, "netbird.stage") ||
|
baseDomain = domain[:len(domain)-1]
|
||||||
strings.HasSuffix(domain, ".domain") {
|
}
|
||||||
|
|
||||||
|
if strings.HasSuffix(baseDomain, "netbird.io") ||
|
||||||
|
strings.HasSuffix(baseDomain, "netbird.selfhosted") ||
|
||||||
|
strings.HasSuffix(baseDomain, "netbird.cloud") ||
|
||||||
|
strings.HasSuffix(baseDomain, "netbird.stage") ||
|
||||||
|
strings.HasSuffix(baseDomain, anonTLD) {
|
||||||
return domain
|
return domain
|
||||||
}
|
}
|
||||||
|
|
||||||
parts := strings.Split(domain, ".")
|
parts := strings.Split(baseDomain, ".")
|
||||||
if len(parts) < 2 {
|
if len(parts) < 2 {
|
||||||
return domain
|
return domain
|
||||||
}
|
}
|
||||||
|
|
||||||
baseDomain := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
baseForLookup := parts[len(parts)-2] + "." + parts[len(parts)-1]
|
||||||
|
|
||||||
anonymized, ok := a.domainAnonymizer[baseDomain]
|
anonymized, ok := a.domainAnonymizer[baseForLookup]
|
||||||
if !ok {
|
if !ok {
|
||||||
anonymizedBase := "anon-" + generateRandomString(5) + ".domain"
|
anonymizedBase := "anon-" + generateRandomString(5) + anonTLD
|
||||||
a.domainAnonymizer[baseDomain] = anonymizedBase
|
a.domainAnonymizer[baseForLookup] = anonymizedBase
|
||||||
anonymized = anonymizedBase
|
anonymized = anonymizedBase
|
||||||
}
|
}
|
||||||
|
|
||||||
return strings.Replace(domain, baseDomain, anonymized, 1)
|
result := strings.Replace(baseDomain, baseForLookup, anonymized, 1)
|
||||||
|
if hasDot {
|
||||||
|
result += "."
|
||||||
|
}
|
||||||
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Anonymizer) AnonymizeURI(uri string) string {
|
func (a *Anonymizer) AnonymizeURI(uri string) string {
|
||||||
@@ -152,27 +168,22 @@ func (a *Anonymizer) AnonymizeString(str string) string {
|
|||||||
return str
|
return str
|
||||||
}
|
}
|
||||||
|
|
||||||
// AnonymizeSchemeURI finds and anonymizes URIs with stun, stuns, turn, and turns schemes.
|
// AnonymizeSchemeURI finds and anonymizes URIs with ws, wss, rel, rels, stun, stuns, turn, and turns schemes.
|
||||||
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
func (a *Anonymizer) AnonymizeSchemeURI(text string) string {
|
||||||
re := regexp.MustCompile(`(?i)\b(stuns?:|turns?:|https?://)\S+\b`)
|
re := regexp.MustCompile(`(?i)\b(wss?://|rels?://|stuns?:|turns?:|https?://)\S+\b`)
|
||||||
|
|
||||||
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
return re.ReplaceAllStringFunc(text, a.AnonymizeURI)
|
||||||
}
|
}
|
||||||
|
|
||||||
// AnonymizeDNSLogLine anonymizes domain names in DNS log entries by replacing them with a random string.
|
|
||||||
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
func (a *Anonymizer) AnonymizeDNSLogLine(logEntry string) string {
|
||||||
domainPattern := `dns\.Question{Name:"([^"]+)",`
|
return a.domainKeyRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
||||||
domainRegex := regexp.MustCompile(domainPattern)
|
parts := strings.SplitN(match, "=", 2)
|
||||||
|
|
||||||
return domainRegex.ReplaceAllStringFunc(logEntry, func(match string) string {
|
|
||||||
parts := strings.Split(match, `"`)
|
|
||||||
if len(parts) >= 2 {
|
if len(parts) >= 2 {
|
||||||
domain := parts[1]
|
domain := parts[1]
|
||||||
if strings.HasSuffix(domain, ".domain") {
|
if strings.HasSuffix(domain, anonTLD) {
|
||||||
return match
|
return match
|
||||||
}
|
}
|
||||||
randomDomain := generateRandomString(10) + ".domain"
|
return "domain=" + a.AnonymizeDomain(domain)
|
||||||
return strings.Replace(match, domain, randomDomain, 1)
|
|
||||||
}
|
}
|
||||||
return match
|
return match
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -46,11 +46,59 @@ func TestAnonymizeIP(t *testing.T) {
|
|||||||
|
|
||||||
func TestAnonymizeDNSLogLine(t *testing.T) {
|
func TestAnonymizeDNSLogLine(t *testing.T) {
|
||||||
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
anonymizer := anonymize.NewAnonymizer(netip.Addr{}, netip.Addr{})
|
||||||
testLog := `2024-04-23T20:01:11+02:00 TRAC client/internal/dns/local.go:25: received question: dns.Question{Name:"example.com", Qtype:0x1c, Qclass:0x1}`
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
input string
|
||||||
|
original string
|
||||||
|
expect string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic domain with trailing content",
|
||||||
|
input: "received DNS request for DNS forwarder: domain=example.com: something happened with code=123",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `received DNS request for DNS forwarder: domain=anon-[a-zA-Z0-9]+\.domain: something happened with code=123`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Domain with trailing dot",
|
||||||
|
input: "domain=example.com. processing request with status=pending",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `domain=anon-[a-zA-Z0-9]+\.domain\. processing request with status=pending`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple domains in log",
|
||||||
|
input: "forward domain=first.com status=ok, redirect to domain=second.com port=443",
|
||||||
|
original: "first.com", // testing just one is sufficient as AnonymizeDomain is tested separately
|
||||||
|
expect: `forward domain=anon-[a-zA-Z0-9]+\.domain status=ok, redirect to domain=anon-[a-zA-Z0-9]+\.domain port=443`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Already anonymized domain",
|
||||||
|
input: "got request domain=anon-xyz123.domain from=client1 to=server2",
|
||||||
|
original: "", // nothing should be anonymized
|
||||||
|
expect: `got request domain=anon-xyz123\.domain from=client1 to=server2`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Subdomain with trailing dot",
|
||||||
|
input: "domain=sub.example.com. next_hop=10.0.0.1 proto=udp",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `domain=sub\.anon-[a-zA-Z0-9]+\.domain\. next_hop=10\.0\.0\.1 proto=udp`,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Handler chain pattern log",
|
||||||
|
input: "pattern: domain=example.com. original: domain=*.example.com. wildcard=true priority=100",
|
||||||
|
original: "example.com",
|
||||||
|
expect: `pattern: domain=anon-[a-zA-Z0-9]+\.domain\. original: domain=\*\.anon-[a-zA-Z0-9]+\.domain\. wildcard=true priority=100`,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
result := anonymizer.AnonymizeDNSLogLine(testLog)
|
for _, tc := range tests {
|
||||||
require.NotEqual(t, testLog, result)
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
assert.NotContains(t, result, "example.com")
|
result := anonymizer.AnonymizeDNSLogLine(tc.input)
|
||||||
|
if tc.original != "" {
|
||||||
|
assert.NotContains(t, result, tc.original)
|
||||||
|
}
|
||||||
|
assert.Regexp(t, tc.expect, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestAnonymizeDomain(t *testing.T) {
|
func TestAnonymizeDomain(t *testing.T) {
|
||||||
@@ -67,18 +115,36 @@ func TestAnonymizeDomain(t *testing.T) {
|
|||||||
`^anon-[a-zA-Z0-9]+\.domain$`,
|
`^anon-[a-zA-Z0-9]+\.domain$`,
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"Domain with Trailing Dot",
|
||||||
|
"example.com.",
|
||||||
|
`^anon-[a-zA-Z0-9]+\.domain.$`,
|
||||||
|
true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"Subdomain",
|
"Subdomain",
|
||||||
"sub.example.com",
|
"sub.example.com",
|
||||||
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
|
`^sub\.anon-[a-zA-Z0-9]+\.domain$`,
|
||||||
true,
|
true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"Subdomain with Trailing Dot",
|
||||||
|
"sub.example.com.",
|
||||||
|
`^sub\.anon-[a-zA-Z0-9]+\.domain.$`,
|
||||||
|
true,
|
||||||
|
},
|
||||||
{
|
{
|
||||||
"Protected Domain",
|
"Protected Domain",
|
||||||
"netbird.io",
|
"netbird.io",
|
||||||
`^netbird\.io$`,
|
`^netbird\.io$`,
|
||||||
false,
|
false,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
"Protected Domain with Trailing Dot",
|
||||||
|
"netbird.io.",
|
||||||
|
`^netbird\.io.$`,
|
||||||
|
false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
@@ -140,8 +206,16 @@ func TestAnonymizeSchemeURI(t *testing.T) {
|
|||||||
expect string
|
expect string
|
||||||
}{
|
}{
|
||||||
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
|
{"STUN URI in text", "Connection made via stun:example.com", `Connection made via stun:anon-[a-zA-Z0-9]+\.domain`},
|
||||||
|
{"STUNS URI in message", "Secure connection to stuns:example.com:443", `Secure connection to stuns:anon-[a-zA-Z0-9]+\.domain:443`},
|
||||||
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
|
{"TURN URI in log", "Failed attempt turn:some.example.com:3478?transport=tcp: retrying", `Failed attempt turn:some.anon-[a-zA-Z0-9]+\.domain:3478\?transport=tcp: retrying`},
|
||||||
|
{"TURNS URI in message", "Secure connection to turns:example.com:5349", `Secure connection to turns:anon-[a-zA-Z0-9]+\.domain:5349`},
|
||||||
|
{"HTTP URI in text", "Visit http://example.com for more", `Visit http://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||||
|
{"HTTPS URI in CAPS", "Visit HTTPS://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||||
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
{"HTTPS URI in message", "Visit https://example.com for more", `Visit https://anon-[a-zA-Z0-9]+\.domain for more`},
|
||||||
|
{"WS URI in log", "Connection established to ws://example.com:8080", `Connection established to ws://anon-[a-zA-Z0-9]+\.domain:8080`},
|
||||||
|
{"WSS URI in message", "Secure connection to wss://example.com", `Secure connection to wss://anon-[a-zA-Z0-9]+\.domain`},
|
||||||
|
{"Rel URI in text", "Relaying to rel://example.com", `Relaying to rel://anon-[a-zA-Z0-9]+\.domain`},
|
||||||
|
{"Rels URI in message", "Relaying to rels://example.com", `Relaying to rels://anon-[a-zA-Z0-9]+\.domain`},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range tests {
|
for _, tc := range tests {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package cmd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -12,6 +13,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@@ -61,6 +63,15 @@ var forCmd = &cobra.Command{
|
|||||||
RunE: runForDuration,
|
RunE: runForDuration,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
var persistenceCmd = &cobra.Command{
|
||||||
|
Use: "persistence [on|off]",
|
||||||
|
Short: "Set network map memory persistence",
|
||||||
|
Long: `Configure whether the latest network map should persist in memory. When enabled, the last known network map will be kept in memory.`,
|
||||||
|
Example: " netbird debug persistence on",
|
||||||
|
Args: cobra.ExactArgs(1),
|
||||||
|
RunE: setNetworkMapPersistence,
|
||||||
|
}
|
||||||
|
|
||||||
func debugBundle(cmd *cobra.Command, _ []string) error {
|
func debugBundle(cmd *cobra.Command, _ []string) error {
|
||||||
conn, err := getClient(cmd)
|
conn, err := getClient(cmd)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -75,7 +86,7 @@ func debugBundle(cmd *cobra.Command, _ []string) error {
|
|||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
Status: getStatusOutput(cmd),
|
Status: getStatusOutput(cmd, anonymizeFlag),
|
||||||
SystemInfo: debugSystemInfoFlag,
|
SystemInfo: debugSystemInfoFlag,
|
||||||
})
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -171,6 +182,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
|
|
||||||
time.Sleep(1 * time.Second)
|
time.Sleep(1 * time.Second)
|
||||||
|
|
||||||
|
// Enable network map persistence before bringing the service up
|
||||||
|
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||||
|
Enabled: true,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to enable network map persistence: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
if _, err := client.Up(cmd.Context(), &proto.UpRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to up: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
@@ -179,7 +197,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
time.Sleep(3 * time.Second)
|
time.Sleep(3 * time.Second)
|
||||||
|
|
||||||
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
headerPostUp := fmt.Sprintf("----- Netbird post-up - Timestamp: %s", time.Now().Format(time.RFC3339))
|
||||||
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd))
|
statusOutput := fmt.Sprintf("%s\n%s", headerPostUp, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
|
||||||
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
if waitErr := waitForDurationOrCancel(cmd.Context(), duration, cmd); waitErr != nil {
|
||||||
return waitErr
|
return waitErr
|
||||||
@@ -189,7 +207,7 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
cmd.Println("Creating debug bundle...")
|
cmd.Println("Creating debug bundle...")
|
||||||
|
|
||||||
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
headerPreDown := fmt.Sprintf("----- Netbird pre-down - Timestamp: %s - Duration: %s", time.Now().Format(time.RFC3339), duration)
|
||||||
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd))
|
statusOutput = fmt.Sprintf("%s\n%s\n%s", statusOutput, headerPreDown, getStatusOutput(cmd, anonymizeFlag))
|
||||||
|
|
||||||
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
resp, err := client.DebugBundle(cmd.Context(), &proto.DebugBundleRequest{
|
||||||
Anonymize: anonymizeFlag,
|
Anonymize: anonymizeFlag,
|
||||||
@@ -200,6 +218,13 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to bundle debug: %v", status.Convert(err).Message())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Disable network map persistence after creating the debug bundle
|
||||||
|
if _, err := client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||||
|
Enabled: false,
|
||||||
|
}); err != nil {
|
||||||
|
return fmt.Errorf("failed to disable network map persistence: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
if stateWasDown {
|
if stateWasDown {
|
||||||
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
if _, err := client.Down(cmd.Context(), &proto.DownRequest{}); err != nil {
|
||||||
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
return fmt.Errorf("failed to down: %v", status.Convert(err).Message())
|
||||||
@@ -219,13 +244,43 @@ func runForDuration(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func getStatusOutput(cmd *cobra.Command) string {
|
func setNetworkMapPersistence(cmd *cobra.Command, args []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf(errCloseConnection, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
persistence := strings.ToLower(args[0])
|
||||||
|
if persistence != "on" && persistence != "off" {
|
||||||
|
return fmt.Errorf("invalid persistence value: %s. Use 'on' or 'off'", args[0])
|
||||||
|
}
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
_, err = client.SetNetworkMapPersistence(cmd.Context(), &proto.SetNetworkMapPersistenceRequest{
|
||||||
|
Enabled: persistence == "on",
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to set network map persistence: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("Network map persistence set to: %s\n", persistence)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getStatusOutput(cmd *cobra.Command, anon bool) string {
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
statusResp, err := getStatus(cmd.Context())
|
statusResp, err := getStatus(cmd.Context())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
cmd.PrintErrf("Failed to get status: %v\n", err)
|
cmd.PrintErrf("Failed to get status: %v\n", err)
|
||||||
} else {
|
} else {
|
||||||
statusOutputString = parseToFullDetailSummary(convertToStatusOutputOverview(statusResp))
|
statusOutputString = nbstatus.ParseToFullDetailSummary(
|
||||||
|
nbstatus.ConvertToStatusOutputOverview(statusResp, anon, "", nil, nil, nil),
|
||||||
|
)
|
||||||
}
|
}
|
||||||
return statusOutputString
|
return statusOutputString
|
||||||
}
|
}
|
||||||
|
|||||||
98
client/cmd/forwarding_rules.go
Normal file
98
client/cmd/forwarding_rules.go
Normal file
@@ -0,0 +1,98 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"sort"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var forwardingRulesCmd = &cobra.Command{
|
||||||
|
Use: "forwarding",
|
||||||
|
Short: "List forwarding rules",
|
||||||
|
Long: `Commands to list forwarding rules.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var forwardingRulesListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List forwarding rules",
|
||||||
|
Example: " netbird forwarding list",
|
||||||
|
Long: "Commands to list forwarding rules.",
|
||||||
|
RunE: listForwardingRules,
|
||||||
|
}
|
||||||
|
|
||||||
|
func listForwardingRules(cmd *cobra.Command, _ []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.ForwardingRules(cmd.Context(), &proto.EmptyRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.GetRules()) == 0 {
|
||||||
|
cmd.Println("No forwarding rules available.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
printForwardingRules(cmd, resp.GetRules())
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printForwardingRules(cmd *cobra.Command, rules []*proto.ForwardingRule) {
|
||||||
|
cmd.Println("Available forwarding rules:")
|
||||||
|
|
||||||
|
// Sort rules by translated address
|
||||||
|
sort.Slice(rules, func(i, j int) bool {
|
||||||
|
if rules[i].GetTranslatedAddress() != rules[j].GetTranslatedAddress() {
|
||||||
|
return rules[i].GetTranslatedAddress() < rules[j].GetTranslatedAddress()
|
||||||
|
}
|
||||||
|
if rules[i].GetProtocol() != rules[j].GetProtocol() {
|
||||||
|
return rules[i].GetProtocol() < rules[j].GetProtocol()
|
||||||
|
}
|
||||||
|
|
||||||
|
return getFirstPort(rules[i].GetDestinationPort()) < getFirstPort(rules[j].GetDestinationPort())
|
||||||
|
})
|
||||||
|
|
||||||
|
var lastIP string
|
||||||
|
for _, rule := range rules {
|
||||||
|
dPort := portToString(rule.GetDestinationPort())
|
||||||
|
tPort := portToString(rule.GetTranslatedPort())
|
||||||
|
if lastIP != rule.GetTranslatedAddress() {
|
||||||
|
lastIP = rule.GetTranslatedAddress()
|
||||||
|
cmd.Printf("\nTranslated peer: %s\n", rule.GetTranslatedHostname())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf(" Local %s/%s to %s:%s\n", rule.GetProtocol(), dPort, rule.GetTranslatedAddress(), tPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getFirstPort(portInfo *proto.PortInfo) int {
|
||||||
|
switch v := portInfo.PortSelection.(type) {
|
||||||
|
case *proto.PortInfo_Port:
|
||||||
|
return int(v.Port)
|
||||||
|
case *proto.PortInfo_Range_:
|
||||||
|
return int(v.Range.GetStart())
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func portToString(translatedPort *proto.PortInfo) string {
|
||||||
|
switch v := translatedPort.PortSelection.(type) {
|
||||||
|
case *proto.PortInfo_Port:
|
||||||
|
return fmt.Sprintf("%d", v.Port)
|
||||||
|
case *proto.PortInfo_Range_:
|
||||||
|
return fmt.Sprintf("%d-%d", v.Range.GetStart(), v.Range.GetEnd())
|
||||||
|
default:
|
||||||
|
return "No port specified"
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -85,11 +85,17 @@ var loginCmd = &cobra.Command{
|
|||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
|
||||||
|
var dnsLabelsReq []string
|
||||||
|
if dnsLabelsValidated != nil {
|
||||||
|
dnsLabelsReq = dnsLabelsValidated.ToSafeStringList()
|
||||||
|
}
|
||||||
|
|
||||||
loginRequest := proto.LoginRequest{
|
loginRequest := proto.LoginRequest{
|
||||||
SetupKey: providedSetupKey,
|
SetupKey: providedSetupKey,
|
||||||
ManagementUrl: managementURL,
|
ManagementUrl: managementURL,
|
||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
|
DnsLabels: dnsLabelsReq,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
|
|||||||
173
client/cmd/networks.go
Normal file
173
client/cmd/networks.go
Normal file
@@ -0,0 +1,173 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var appendFlag bool
|
||||||
|
|
||||||
|
var networksCMD = &cobra.Command{
|
||||||
|
Use: "networks",
|
||||||
|
Aliases: []string{"routes"},
|
||||||
|
Short: "Manage networks",
|
||||||
|
Long: `Commands to list, select, or deselect networks. Replaces the "routes" command.`,
|
||||||
|
}
|
||||||
|
|
||||||
|
var routesListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List networks",
|
||||||
|
Example: " netbird networks list",
|
||||||
|
Long: "List all available network routes.",
|
||||||
|
RunE: networksList,
|
||||||
|
}
|
||||||
|
|
||||||
|
var routesSelectCmd = &cobra.Command{
|
||||||
|
Use: "select network...|all",
|
||||||
|
Short: "Select network",
|
||||||
|
Long: "Select a list of networks by identifiers or 'all' to clear all selections and to accept all (including new) networks.\nDefault mode is replace, use -a to append to already selected networks.",
|
||||||
|
Example: " netbird networks select all\n netbird networks select route1 route2\n netbird routes select -a route3",
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: networksSelect,
|
||||||
|
}
|
||||||
|
|
||||||
|
var routesDeselectCmd = &cobra.Command{
|
||||||
|
Use: "deselect network...|all",
|
||||||
|
Short: "Deselect networks",
|
||||||
|
Long: "Deselect previously selected networks by identifiers or 'all' to disable accepting any networks.",
|
||||||
|
Example: " netbird networks deselect all\n netbird networks deselect route1 route2",
|
||||||
|
Args: cobra.MinimumNArgs(1),
|
||||||
|
RunE: networksDeselect,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current network selection instead of replacing")
|
||||||
|
}
|
||||||
|
|
||||||
|
func networksList(cmd *cobra.Command, _ []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.ListNetworks(cmd.Context(), &proto.ListNetworksRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list network: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(resp.Routes) == 0 {
|
||||||
|
cmd.Println("No networks available.")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
printNetworks(cmd, resp)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printNetworks(cmd *cobra.Command, resp *proto.ListNetworksResponse) {
|
||||||
|
cmd.Println("Available Networks:")
|
||||||
|
for _, route := range resp.Routes {
|
||||||
|
printNetwork(cmd, route)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func printNetwork(cmd *cobra.Command, route *proto.Network) {
|
||||||
|
selectedStatus := getSelectedStatus(route)
|
||||||
|
domains := route.GetDomains()
|
||||||
|
|
||||||
|
if len(domains) > 0 {
|
||||||
|
printDomainRoute(cmd, route, domains, selectedStatus)
|
||||||
|
} else {
|
||||||
|
printNetworkRoute(cmd, route, selectedStatus)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getSelectedStatus(route *proto.Network) string {
|
||||||
|
if route.GetSelected() {
|
||||||
|
return "Selected"
|
||||||
|
}
|
||||||
|
return "Not Selected"
|
||||||
|
}
|
||||||
|
|
||||||
|
func printDomainRoute(cmd *cobra.Command, route *proto.Network, domains []string, selectedStatus string) {
|
||||||
|
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
||||||
|
resolvedIPs := route.GetResolvedIPs()
|
||||||
|
|
||||||
|
if len(resolvedIPs) > 0 {
|
||||||
|
printResolvedIPs(cmd, domains, resolvedIPs)
|
||||||
|
} else {
|
||||||
|
cmd.Printf(" Resolved IPs: -\n")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func printNetworkRoute(cmd *cobra.Command, route *proto.Network, selectedStatus string) {
|
||||||
|
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetRange(), selectedStatus)
|
||||||
|
}
|
||||||
|
|
||||||
|
func printResolvedIPs(cmd *cobra.Command, _ []string, resolvedIPs map[string]*proto.IPList) {
|
||||||
|
cmd.Printf(" Resolved IPs:\n")
|
||||||
|
for resolvedDomain, ipList := range resolvedIPs {
|
||||||
|
cmd.Printf(" [%s]: %s\n", resolvedDomain, strings.Join(ipList.GetIps(), ", "))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func networksSelect(cmd *cobra.Command, args []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
req := &proto.SelectNetworksRequest{
|
||||||
|
NetworkIDs: args,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 1 && args[0] == "all" {
|
||||||
|
req.All = true
|
||||||
|
} else if appendFlag {
|
||||||
|
req.Append = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.SelectNetworks(cmd.Context(), req); err != nil {
|
||||||
|
return fmt.Errorf("failed to select networks: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("Networks selected successfully.")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func networksDeselect(cmd *cobra.Command, args []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
req := &proto.SelectNetworksRequest{
|
||||||
|
NetworkIDs: args,
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(args) == 1 && args[0] == "all" {
|
||||||
|
req.All = true
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := client.DeselectNetworks(cmd.Context(), req); err != nil {
|
||||||
|
return fmt.Errorf("failed to deselect networks: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Println("Networks deselected successfully.")
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
33
client/cmd/pprof.go
Normal file
33
client/cmd/pprof.go
Normal file
@@ -0,0 +1,33 @@
|
|||||||
|
//go:build pprof
|
||||||
|
// +build pprof
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/http"
|
||||||
|
_ "net/http/pprof"
|
||||||
|
"os"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
addr := pprofAddr()
|
||||||
|
go pprof(addr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func pprofAddr() string {
|
||||||
|
listenAddr := os.Getenv("NB_PPROF_ADDR")
|
||||||
|
if listenAddr == "" {
|
||||||
|
return "localhost:6969"
|
||||||
|
}
|
||||||
|
|
||||||
|
return listenAddr
|
||||||
|
}
|
||||||
|
|
||||||
|
func pprof(listenAddr string) {
|
||||||
|
log.Infof("listening pprof on: %s\n", listenAddr)
|
||||||
|
if err := http.ListenAndServe(listenAddr, nil); err != nil {
|
||||||
|
log.Fatalf("Failed to start pprof: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -38,6 +38,7 @@ const (
|
|||||||
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
extraIFaceBlackListFlag = "extra-iface-blacklist"
|
||||||
dnsRouteIntervalFlag = "dns-router-interval"
|
dnsRouteIntervalFlag = "dns-router-interval"
|
||||||
systemInfoFlag = "system-info"
|
systemInfoFlag = "system-info"
|
||||||
|
blockLANAccessFlag = "block-lan-access"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
@@ -73,6 +74,7 @@ var (
|
|||||||
anonymizeFlag bool
|
anonymizeFlag bool
|
||||||
debugSystemInfoFlag bool
|
debugSystemInfoFlag bool
|
||||||
dnsRouteInterval time.Duration
|
dnsRouteInterval time.Duration
|
||||||
|
blockLANAccess bool
|
||||||
|
|
||||||
rootCmd = &cobra.Command{
|
rootCmd = &cobra.Command{
|
||||||
Use: "netbird",
|
Use: "netbird",
|
||||||
@@ -142,19 +144,23 @@ func init() {
|
|||||||
rootCmd.AddCommand(loginCmd)
|
rootCmd.AddCommand(loginCmd)
|
||||||
rootCmd.AddCommand(versionCmd)
|
rootCmd.AddCommand(versionCmd)
|
||||||
rootCmd.AddCommand(sshCmd)
|
rootCmd.AddCommand(sshCmd)
|
||||||
rootCmd.AddCommand(routesCmd)
|
rootCmd.AddCommand(networksCMD)
|
||||||
|
rootCmd.AddCommand(forwardingRulesCmd)
|
||||||
rootCmd.AddCommand(debugCmd)
|
rootCmd.AddCommand(debugCmd)
|
||||||
|
|
||||||
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
serviceCmd.AddCommand(runCmd, startCmd, stopCmd, restartCmd) // service control commands are subcommands of service
|
||||||
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
serviceCmd.AddCommand(installCmd, uninstallCmd) // service installer commands are subcommands of service
|
||||||
|
|
||||||
routesCmd.AddCommand(routesListCmd)
|
networksCMD.AddCommand(routesListCmd)
|
||||||
routesCmd.AddCommand(routesSelectCmd, routesDeselectCmd)
|
networksCMD.AddCommand(routesSelectCmd, routesDeselectCmd)
|
||||||
|
|
||||||
|
forwardingRulesCmd.AddCommand(forwardingRulesListCmd)
|
||||||
|
|
||||||
debugCmd.AddCommand(debugBundleCmd)
|
debugCmd.AddCommand(debugBundleCmd)
|
||||||
debugCmd.AddCommand(logCmd)
|
debugCmd.AddCommand(logCmd)
|
||||||
logCmd.AddCommand(logLevelCmd)
|
logCmd.AddCommand(logLevelCmd)
|
||||||
debugCmd.AddCommand(forCmd)
|
debugCmd.AddCommand(forCmd)
|
||||||
|
debugCmd.AddCommand(persistenceCmd)
|
||||||
|
|
||||||
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
upCmd.PersistentFlags().StringSliceVar(&natExternalIPs, externalIPMapFlag, nil,
|
||||||
`Sets external IPs maps between local addresses and interfaces.`+
|
`Sets external IPs maps between local addresses and interfaces.`+
|
||||||
|
|||||||
@@ -1,174 +0,0 @@
|
|||||||
package cmd
|
|
||||||
|
|
||||||
import (
|
|
||||||
"fmt"
|
|
||||||
"strings"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
|
||||||
"google.golang.org/grpc/status"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
|
||||||
|
|
||||||
var appendFlag bool
|
|
||||||
|
|
||||||
var routesCmd = &cobra.Command{
|
|
||||||
Use: "routes",
|
|
||||||
Short: "Manage network routes",
|
|
||||||
Long: `Commands to list, select, or deselect network routes.`,
|
|
||||||
}
|
|
||||||
|
|
||||||
var routesListCmd = &cobra.Command{
|
|
||||||
Use: "list",
|
|
||||||
Aliases: []string{"ls"},
|
|
||||||
Short: "List routes",
|
|
||||||
Example: " netbird routes list",
|
|
||||||
Long: "List all available network routes.",
|
|
||||||
RunE: routesList,
|
|
||||||
}
|
|
||||||
|
|
||||||
var routesSelectCmd = &cobra.Command{
|
|
||||||
Use: "select route...|all",
|
|
||||||
Short: "Select routes",
|
|
||||||
Long: "Select a list of routes by identifiers or 'all' to clear all selections and to accept all (including new) routes.\nDefault mode is replace, use -a to append to already selected routes.",
|
|
||||||
Example: " netbird routes select all\n netbird routes select route1 route2\n netbird routes select -a route3",
|
|
||||||
Args: cobra.MinimumNArgs(1),
|
|
||||||
RunE: routesSelect,
|
|
||||||
}
|
|
||||||
|
|
||||||
var routesDeselectCmd = &cobra.Command{
|
|
||||||
Use: "deselect route...|all",
|
|
||||||
Short: "Deselect routes",
|
|
||||||
Long: "Deselect previously selected routes by identifiers or 'all' to disable accepting any routes.",
|
|
||||||
Example: " netbird routes deselect all\n netbird routes deselect route1 route2",
|
|
||||||
Args: cobra.MinimumNArgs(1),
|
|
||||||
RunE: routesDeselect,
|
|
||||||
}
|
|
||||||
|
|
||||||
func init() {
|
|
||||||
routesSelectCmd.PersistentFlags().BoolVarP(&appendFlag, "append", "a", false, "Append to current route selection instead of replacing")
|
|
||||||
}
|
|
||||||
|
|
||||||
func routesList(cmd *cobra.Command, _ []string) error {
|
|
||||||
conn, err := getClient(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
resp, err := client.ListRoutes(cmd.Context(), &proto.ListRoutesRequest{})
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("failed to list routes: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(resp.Routes) == 0 {
|
|
||||||
cmd.Println("No routes available.")
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
printRoutes(cmd, resp)
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func printRoutes(cmd *cobra.Command, resp *proto.ListRoutesResponse) {
|
|
||||||
cmd.Println("Available Routes:")
|
|
||||||
for _, route := range resp.Routes {
|
|
||||||
printRoute(cmd, route)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func printRoute(cmd *cobra.Command, route *proto.Route) {
|
|
||||||
selectedStatus := getSelectedStatus(route)
|
|
||||||
domains := route.GetDomains()
|
|
||||||
|
|
||||||
if len(domains) > 0 {
|
|
||||||
printDomainRoute(cmd, route, domains, selectedStatus)
|
|
||||||
} else {
|
|
||||||
printNetworkRoute(cmd, route, selectedStatus)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func getSelectedStatus(route *proto.Route) string {
|
|
||||||
if route.GetSelected() {
|
|
||||||
return "Selected"
|
|
||||||
}
|
|
||||||
return "Not Selected"
|
|
||||||
}
|
|
||||||
|
|
||||||
func printDomainRoute(cmd *cobra.Command, route *proto.Route, domains []string, selectedStatus string) {
|
|
||||||
cmd.Printf("\n - ID: %s\n Domains: %s\n Status: %s\n", route.GetID(), strings.Join(domains, ", "), selectedStatus)
|
|
||||||
resolvedIPs := route.GetResolvedIPs()
|
|
||||||
|
|
||||||
if len(resolvedIPs) > 0 {
|
|
||||||
printResolvedIPs(cmd, domains, resolvedIPs)
|
|
||||||
} else {
|
|
||||||
cmd.Printf(" Resolved IPs: -\n")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func printNetworkRoute(cmd *cobra.Command, route *proto.Route, selectedStatus string) {
|
|
||||||
cmd.Printf("\n - ID: %s\n Network: %s\n Status: %s\n", route.GetID(), route.GetNetwork(), selectedStatus)
|
|
||||||
}
|
|
||||||
|
|
||||||
func printResolvedIPs(cmd *cobra.Command, domains []string, resolvedIPs map[string]*proto.IPList) {
|
|
||||||
cmd.Printf(" Resolved IPs:\n")
|
|
||||||
for _, domain := range domains {
|
|
||||||
if ipList, exists := resolvedIPs[domain]; exists {
|
|
||||||
cmd.Printf(" [%s]: %s\n", domain, strings.Join(ipList.GetIps(), ", "))
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func routesSelect(cmd *cobra.Command, args []string) error {
|
|
||||||
conn, err := getClient(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
req := &proto.SelectRoutesRequest{
|
|
||||||
RouteIDs: args,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) == 1 && args[0] == "all" {
|
|
||||||
req.All = true
|
|
||||||
} else if appendFlag {
|
|
||||||
req.Append = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.SelectRoutes(cmd.Context(), req); err != nil {
|
|
||||||
return fmt.Errorf("failed to select routes: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Println("Routes selected successfully.")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func routesDeselect(cmd *cobra.Command, args []string) error {
|
|
||||||
conn, err := getClient(cmd)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
defer conn.Close()
|
|
||||||
|
|
||||||
client := proto.NewDaemonServiceClient(conn)
|
|
||||||
req := &proto.SelectRoutesRequest{
|
|
||||||
RouteIDs: args,
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(args) == 1 && args[0] == "all" {
|
|
||||||
req.All = true
|
|
||||||
}
|
|
||||||
|
|
||||||
if _, err := client.DeselectRoutes(cmd.Context(), req); err != nil {
|
|
||||||
return fmt.Errorf("failed to deselect routes: %v", status.Convert(err).Message())
|
|
||||||
}
|
|
||||||
|
|
||||||
cmd.Println("Routes deselected successfully.")
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
@@ -9,7 +9,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
@@ -73,7 +72,7 @@ var sshCmd = &cobra.Command{
|
|||||||
go func() {
|
go func() {
|
||||||
// blocking
|
// blocking
|
||||||
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
if err := runSSH(sshctx, host, []byte(config.SSHKey), cmd); err != nil {
|
||||||
log.Debug(err)
|
cmd.Printf("Error: %v\n", err)
|
||||||
os.Exit(1)
|
os.Exit(1)
|
||||||
}
|
}
|
||||||
cancel()
|
cancel()
|
||||||
|
|||||||
181
client/cmd/state.go
Normal file
181
client/cmd/state.go
Normal file
@@ -0,0 +1,181 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
allFlag bool
|
||||||
|
)
|
||||||
|
|
||||||
|
var stateCmd = &cobra.Command{
|
||||||
|
Use: "state",
|
||||||
|
Short: "Manage daemon state",
|
||||||
|
Long: "Provides commands for managing and inspecting the Netbird daemon state.",
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateListCmd = &cobra.Command{
|
||||||
|
Use: "list",
|
||||||
|
Aliases: []string{"ls"},
|
||||||
|
Short: "List all stored states",
|
||||||
|
Long: "Lists all registered states with their status and basic information.",
|
||||||
|
Example: " netbird state list",
|
||||||
|
RunE: stateList,
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateCleanCmd = &cobra.Command{
|
||||||
|
Use: "clean [state-name]",
|
||||||
|
Short: "Clean stored states",
|
||||||
|
Long: `Clean specific state or all states. The daemon must not be running.
|
||||||
|
This will perform cleanup operations and remove the state.`,
|
||||||
|
Example: ` netbird state clean dns_state
|
||||||
|
netbird state clean --all`,
|
||||||
|
RunE: stateClean,
|
||||||
|
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
// Check mutual exclusivity between --all flag and state-name argument
|
||||||
|
if allFlag && len(args) > 0 {
|
||||||
|
return fmt.Errorf("cannot specify both --all flag and state name")
|
||||||
|
}
|
||||||
|
if !allFlag && len(args) != 1 {
|
||||||
|
return fmt.Errorf("requires a state name argument or --all flag")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
var stateDeleteCmd = &cobra.Command{
|
||||||
|
Use: "delete [state-name]",
|
||||||
|
Short: "Delete stored states",
|
||||||
|
Long: `Delete specific state or all states from storage. The daemon must not be running.
|
||||||
|
This will remove the state without performing any cleanup operations.`,
|
||||||
|
Example: ` netbird state delete dns_state
|
||||||
|
netbird state delete --all`,
|
||||||
|
RunE: stateDelete,
|
||||||
|
PreRunE: func(cmd *cobra.Command, args []string) error {
|
||||||
|
// Check mutual exclusivity between --all flag and state-name argument
|
||||||
|
if allFlag && len(args) > 0 {
|
||||||
|
return fmt.Errorf("cannot specify both --all flag and state name")
|
||||||
|
}
|
||||||
|
if !allFlag && len(args) != 1 {
|
||||||
|
return fmt.Errorf("requires a state name argument or --all flag")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
rootCmd.AddCommand(stateCmd)
|
||||||
|
stateCmd.AddCommand(stateListCmd, stateCleanCmd, stateDeleteCmd)
|
||||||
|
|
||||||
|
stateCleanCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Clean all states")
|
||||||
|
stateDeleteCmd.Flags().BoolVarP(&allFlag, "all", "a", false, "Delete all states")
|
||||||
|
}
|
||||||
|
|
||||||
|
func stateList(cmd *cobra.Command, _ []string) error {
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf(errCloseConnection, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.ListStates(cmd.Context(), &proto.ListStatesRequest{})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to list states: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
cmd.Printf("\nStored states:\n\n")
|
||||||
|
for _, state := range resp.States {
|
||||||
|
cmd.Printf("- %s\n", state.Name)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stateClean(cmd *cobra.Command, args []string) error {
|
||||||
|
var stateName string
|
||||||
|
if !allFlag {
|
||||||
|
stateName = args[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf(errCloseConnection, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.CleanState(cmd.Context(), &proto.CleanStateRequest{
|
||||||
|
StateName: stateName,
|
||||||
|
All: allFlag,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to clean state: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.CleanedStates == 0 {
|
||||||
|
cmd.Println("No states were cleaned")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if allFlag {
|
||||||
|
cmd.Printf("Successfully cleaned %d states\n", resp.CleanedStates)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("Successfully cleaned state %q\n", stateName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func stateDelete(cmd *cobra.Command, args []string) error {
|
||||||
|
var stateName string
|
||||||
|
if !allFlag {
|
||||||
|
stateName = args[0]
|
||||||
|
}
|
||||||
|
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
log.Errorf(errCloseConnection, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.DeleteState(cmd.Context(), &proto.DeleteStateRequest{
|
||||||
|
StateName: stateName,
|
||||||
|
All: allFlag,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("failed to delete state: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
if resp.DeletedStates == 0 {
|
||||||
|
cmd.Println("No states were deleted")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if allFlag {
|
||||||
|
cmd.Printf("Successfully deleted %d states\n", resp.DeletedStates)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("Successfully deleted state %q\n", stateName)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -2,105 +2,20 @@ package cmd
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"os"
|
|
||||||
"runtime"
|
|
||||||
"sort"
|
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/spf13/cobra"
|
"github.com/spf13/cobra"
|
||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
"gopkg.in/yaml.v3"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type peerStateDetailOutput struct {
|
|
||||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
|
||||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
|
||||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
|
||||||
Status string `json:"status" yaml:"status"`
|
|
||||||
LastStatusUpdate time.Time `json:"lastStatusUpdate" yaml:"lastStatusUpdate"`
|
|
||||||
ConnType string `json:"connectionType" yaml:"connectionType"`
|
|
||||||
IceCandidateType iceCandidateType `json:"iceCandidateType" yaml:"iceCandidateType"`
|
|
||||||
IceCandidateEndpoint iceCandidateType `json:"iceCandidateEndpoint" yaml:"iceCandidateEndpoint"`
|
|
||||||
RelayAddress string `json:"relayAddress" yaml:"relayAddress"`
|
|
||||||
LastWireguardHandshake time.Time `json:"lastWireguardHandshake" yaml:"lastWireguardHandshake"`
|
|
||||||
TransferReceived int64 `json:"transferReceived" yaml:"transferReceived"`
|
|
||||||
TransferSent int64 `json:"transferSent" yaml:"transferSent"`
|
|
||||||
Latency time.Duration `json:"latency" yaml:"latency"`
|
|
||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type peersStateOutput struct {
|
|
||||||
Total int `json:"total" yaml:"total"`
|
|
||||||
Connected int `json:"connected" yaml:"connected"`
|
|
||||||
Details []peerStateDetailOutput `json:"details" yaml:"details"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type signalStateOutput struct {
|
|
||||||
URL string `json:"url" yaml:"url"`
|
|
||||||
Connected bool `json:"connected" yaml:"connected"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type managementStateOutput struct {
|
|
||||||
URL string `json:"url" yaml:"url"`
|
|
||||||
Connected bool `json:"connected" yaml:"connected"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type relayStateOutputDetail struct {
|
|
||||||
URI string `json:"uri" yaml:"uri"`
|
|
||||||
Available bool `json:"available" yaml:"available"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type relayStateOutput struct {
|
|
||||||
Total int `json:"total" yaml:"total"`
|
|
||||||
Available int `json:"available" yaml:"available"`
|
|
||||||
Details []relayStateOutputDetail `json:"details" yaml:"details"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type iceCandidateType struct {
|
|
||||||
Local string `json:"local" yaml:"local"`
|
|
||||||
Remote string `json:"remote" yaml:"remote"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type nsServerGroupStateOutput struct {
|
|
||||||
Servers []string `json:"servers" yaml:"servers"`
|
|
||||||
Domains []string `json:"domains" yaml:"domains"`
|
|
||||||
Enabled bool `json:"enabled" yaml:"enabled"`
|
|
||||||
Error string `json:"error" yaml:"error"`
|
|
||||||
}
|
|
||||||
|
|
||||||
type statusOutputOverview struct {
|
|
||||||
Peers peersStateOutput `json:"peers" yaml:"peers"`
|
|
||||||
CliVersion string `json:"cliVersion" yaml:"cliVersion"`
|
|
||||||
DaemonVersion string `json:"daemonVersion" yaml:"daemonVersion"`
|
|
||||||
ManagementState managementStateOutput `json:"management" yaml:"management"`
|
|
||||||
SignalState signalStateOutput `json:"signal" yaml:"signal"`
|
|
||||||
Relays relayStateOutput `json:"relays" yaml:"relays"`
|
|
||||||
IP string `json:"netbirdIp" yaml:"netbirdIp"`
|
|
||||||
PubKey string `json:"publicKey" yaml:"publicKey"`
|
|
||||||
KernelInterface bool `json:"usesKernelInterface" yaml:"usesKernelInterface"`
|
|
||||||
FQDN string `json:"fqdn" yaml:"fqdn"`
|
|
||||||
RosenpassEnabled bool `json:"quantumResistance" yaml:"quantumResistance"`
|
|
||||||
RosenpassPermissive bool `json:"quantumResistancePermissive" yaml:"quantumResistancePermissive"`
|
|
||||||
Routes []string `json:"routes" yaml:"routes"`
|
|
||||||
NSServerGroups []nsServerGroupStateOutput `json:"dnsServers" yaml:"dnsServers"`
|
|
||||||
}
|
|
||||||
|
|
||||||
var (
|
var (
|
||||||
detailFlag bool
|
detailFlag bool
|
||||||
ipv4Flag bool
|
ipv4Flag bool
|
||||||
@@ -171,18 +86,17 @@ func statusFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
outputInformationHolder := convertToStatusOutputOverview(resp)
|
var outputInformationHolder = nbstatus.ConvertToStatusOutputOverview(resp, anonymizeFlag, statusFilter, prefixNamesFilter, prefixNamesFilterMap, ipsFilterMap)
|
||||||
|
|
||||||
var statusOutputString string
|
var statusOutputString string
|
||||||
switch {
|
switch {
|
||||||
case detailFlag:
|
case detailFlag:
|
||||||
statusOutputString = parseToFullDetailSummary(outputInformationHolder)
|
statusOutputString = nbstatus.ParseToFullDetailSummary(outputInformationHolder)
|
||||||
case jsonFlag:
|
case jsonFlag:
|
||||||
statusOutputString, err = parseToJSON(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToJSON(outputInformationHolder)
|
||||||
case yamlFlag:
|
case yamlFlag:
|
||||||
statusOutputString, err = parseToYAML(outputInformationHolder)
|
statusOutputString, err = nbstatus.ParseToYAML(outputInformationHolder)
|
||||||
default:
|
default:
|
||||||
statusOutputString = parseGeneralSummary(outputInformationHolder, false, false, false)
|
statusOutputString = nbstatus.ParseGeneralSummary(outputInformationHolder, false, false, false)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -212,7 +126,6 @@ func getStatus(ctx context.Context) (*proto.StatusResponse, error) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func parseFilters() error {
|
func parseFilters() error {
|
||||||
|
|
||||||
switch strings.ToLower(statusFilter) {
|
switch strings.ToLower(statusFilter) {
|
||||||
case "", "disconnected", "connected":
|
case "", "disconnected", "connected":
|
||||||
if strings.ToLower(statusFilter) != "" {
|
if strings.ToLower(statusFilter) != "" {
|
||||||
@@ -249,173 +162,6 @@ func enableDetailFlagWhenFilterFlag() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func convertToStatusOutputOverview(resp *proto.StatusResponse) statusOutputOverview {
|
|
||||||
pbFullStatus := resp.GetFullStatus()
|
|
||||||
|
|
||||||
managementState := pbFullStatus.GetManagementState()
|
|
||||||
managementOverview := managementStateOutput{
|
|
||||||
URL: managementState.GetURL(),
|
|
||||||
Connected: managementState.GetConnected(),
|
|
||||||
Error: managementState.Error,
|
|
||||||
}
|
|
||||||
|
|
||||||
signalState := pbFullStatus.GetSignalState()
|
|
||||||
signalOverview := signalStateOutput{
|
|
||||||
URL: signalState.GetURL(),
|
|
||||||
Connected: signalState.GetConnected(),
|
|
||||||
Error: signalState.Error,
|
|
||||||
}
|
|
||||||
|
|
||||||
relayOverview := mapRelays(pbFullStatus.GetRelays())
|
|
||||||
peersOverview := mapPeers(resp.GetFullStatus().GetPeers())
|
|
||||||
|
|
||||||
overview := statusOutputOverview{
|
|
||||||
Peers: peersOverview,
|
|
||||||
CliVersion: version.NetbirdVersion(),
|
|
||||||
DaemonVersion: resp.GetDaemonVersion(),
|
|
||||||
ManagementState: managementOverview,
|
|
||||||
SignalState: signalOverview,
|
|
||||||
Relays: relayOverview,
|
|
||||||
IP: pbFullStatus.GetLocalPeerState().GetIP(),
|
|
||||||
PubKey: pbFullStatus.GetLocalPeerState().GetPubKey(),
|
|
||||||
KernelInterface: pbFullStatus.GetLocalPeerState().GetKernelInterface(),
|
|
||||||
FQDN: pbFullStatus.GetLocalPeerState().GetFqdn(),
|
|
||||||
RosenpassEnabled: pbFullStatus.GetLocalPeerState().GetRosenpassEnabled(),
|
|
||||||
RosenpassPermissive: pbFullStatus.GetLocalPeerState().GetRosenpassPermissive(),
|
|
||||||
Routes: pbFullStatus.GetLocalPeerState().GetRoutes(),
|
|
||||||
NSServerGroups: mapNSGroups(pbFullStatus.GetDnsServers()),
|
|
||||||
}
|
|
||||||
|
|
||||||
if anonymizeFlag {
|
|
||||||
anonymizer := anonymize.NewAnonymizer(anonymize.DefaultAddresses())
|
|
||||||
anonymizeOverview(anonymizer, &overview)
|
|
||||||
}
|
|
||||||
|
|
||||||
return overview
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapRelays(relays []*proto.RelayState) relayStateOutput {
|
|
||||||
var relayStateDetail []relayStateOutputDetail
|
|
||||||
|
|
||||||
var relaysAvailable int
|
|
||||||
for _, relay := range relays {
|
|
||||||
available := relay.GetAvailable()
|
|
||||||
relayStateDetail = append(relayStateDetail,
|
|
||||||
relayStateOutputDetail{
|
|
||||||
URI: relay.URI,
|
|
||||||
Available: available,
|
|
||||||
Error: relay.GetError(),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
|
|
||||||
if available {
|
|
||||||
relaysAvailable++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return relayStateOutput{
|
|
||||||
Total: len(relays),
|
|
||||||
Available: relaysAvailable,
|
|
||||||
Details: relayStateDetail,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapNSGroups(servers []*proto.NSGroupState) []nsServerGroupStateOutput {
|
|
||||||
mappedNSGroups := make([]nsServerGroupStateOutput, 0, len(servers))
|
|
||||||
for _, pbNsGroupServer := range servers {
|
|
||||||
mappedNSGroups = append(mappedNSGroups, nsServerGroupStateOutput{
|
|
||||||
Servers: pbNsGroupServer.GetServers(),
|
|
||||||
Domains: pbNsGroupServer.GetDomains(),
|
|
||||||
Enabled: pbNsGroupServer.GetEnabled(),
|
|
||||||
Error: pbNsGroupServer.GetError(),
|
|
||||||
})
|
|
||||||
}
|
|
||||||
return mappedNSGroups
|
|
||||||
}
|
|
||||||
|
|
||||||
func mapPeers(peers []*proto.PeerState) peersStateOutput {
|
|
||||||
var peersStateDetail []peerStateDetailOutput
|
|
||||||
peersConnected := 0
|
|
||||||
for _, pbPeerState := range peers {
|
|
||||||
localICE := ""
|
|
||||||
remoteICE := ""
|
|
||||||
localICEEndpoint := ""
|
|
||||||
remoteICEEndpoint := ""
|
|
||||||
relayServerAddress := ""
|
|
||||||
connType := ""
|
|
||||||
lastHandshake := time.Time{}
|
|
||||||
transferReceived := int64(0)
|
|
||||||
transferSent := int64(0)
|
|
||||||
|
|
||||||
isPeerConnected := pbPeerState.ConnStatus == peer.StatusConnected.String()
|
|
||||||
if skipDetailByFilters(pbPeerState, isPeerConnected) {
|
|
||||||
continue
|
|
||||||
}
|
|
||||||
if isPeerConnected {
|
|
||||||
peersConnected++
|
|
||||||
|
|
||||||
localICE = pbPeerState.GetLocalIceCandidateType()
|
|
||||||
remoteICE = pbPeerState.GetRemoteIceCandidateType()
|
|
||||||
localICEEndpoint = pbPeerState.GetLocalIceCandidateEndpoint()
|
|
||||||
remoteICEEndpoint = pbPeerState.GetRemoteIceCandidateEndpoint()
|
|
||||||
connType = "P2P"
|
|
||||||
if pbPeerState.Relayed {
|
|
||||||
connType = "Relayed"
|
|
||||||
}
|
|
||||||
relayServerAddress = pbPeerState.GetRelayAddress()
|
|
||||||
lastHandshake = pbPeerState.GetLastWireguardHandshake().AsTime().Local()
|
|
||||||
transferReceived = pbPeerState.GetBytesRx()
|
|
||||||
transferSent = pbPeerState.GetBytesTx()
|
|
||||||
}
|
|
||||||
|
|
||||||
timeLocal := pbPeerState.GetConnStatusUpdate().AsTime().Local()
|
|
||||||
peerState := peerStateDetailOutput{
|
|
||||||
IP: pbPeerState.GetIP(),
|
|
||||||
PubKey: pbPeerState.GetPubKey(),
|
|
||||||
Status: pbPeerState.GetConnStatus(),
|
|
||||||
LastStatusUpdate: timeLocal,
|
|
||||||
ConnType: connType,
|
|
||||||
IceCandidateType: iceCandidateType{
|
|
||||||
Local: localICE,
|
|
||||||
Remote: remoteICE,
|
|
||||||
},
|
|
||||||
IceCandidateEndpoint: iceCandidateType{
|
|
||||||
Local: localICEEndpoint,
|
|
||||||
Remote: remoteICEEndpoint,
|
|
||||||
},
|
|
||||||
RelayAddress: relayServerAddress,
|
|
||||||
FQDN: pbPeerState.GetFqdn(),
|
|
||||||
LastWireguardHandshake: lastHandshake,
|
|
||||||
TransferReceived: transferReceived,
|
|
||||||
TransferSent: transferSent,
|
|
||||||
Latency: pbPeerState.GetLatency().AsDuration(),
|
|
||||||
RosenpassEnabled: pbPeerState.GetRosenpassEnabled(),
|
|
||||||
Routes: pbPeerState.GetRoutes(),
|
|
||||||
}
|
|
||||||
|
|
||||||
peersStateDetail = append(peersStateDetail, peerState)
|
|
||||||
}
|
|
||||||
|
|
||||||
sortPeersByIP(peersStateDetail)
|
|
||||||
|
|
||||||
peersOverview := peersStateOutput{
|
|
||||||
Total: len(peersStateDetail),
|
|
||||||
Connected: peersConnected,
|
|
||||||
Details: peersStateDetail,
|
|
||||||
}
|
|
||||||
return peersOverview
|
|
||||||
}
|
|
||||||
|
|
||||||
func sortPeersByIP(peersStateDetail []peerStateDetailOutput) {
|
|
||||||
if len(peersStateDetail) > 0 {
|
|
||||||
sort.SliceStable(peersStateDetail, func(i, j int) bool {
|
|
||||||
iAddr, _ := netip.ParseAddr(peersStateDetail[i].IP)
|
|
||||||
jAddr, _ := netip.ParseAddr(peersStateDetail[j].IP)
|
|
||||||
return iAddr.Compare(jAddr) == -1
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseInterfaceIP(interfaceIP string) string {
|
func parseInterfaceIP(interfaceIP string) string {
|
||||||
ip, _, err := net.ParseCIDR(interfaceIP)
|
ip, _, err := net.ParseCIDR(interfaceIP)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -423,436 +169,3 @@ func parseInterfaceIP(interfaceIP string) string {
|
|||||||
}
|
}
|
||||||
return fmt.Sprintf("%s\n", ip)
|
return fmt.Sprintf("%s\n", ip)
|
||||||
}
|
}
|
||||||
|
|
||||||
func parseToJSON(overview statusOutputOverview) (string, error) {
|
|
||||||
jsonBytes, err := json.Marshal(overview)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("json marshal failed")
|
|
||||||
}
|
|
||||||
return string(jsonBytes), err
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseToYAML(overview statusOutputOverview) (string, error) {
|
|
||||||
yamlBytes, err := yaml.Marshal(overview)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("yaml marshal failed")
|
|
||||||
}
|
|
||||||
return string(yamlBytes), nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseGeneralSummary(overview statusOutputOverview, showURL bool, showRelays bool, showNameServers bool) string {
|
|
||||||
var managementConnString string
|
|
||||||
if overview.ManagementState.Connected {
|
|
||||||
managementConnString = "Connected"
|
|
||||||
if showURL {
|
|
||||||
managementConnString = fmt.Sprintf("%s to %s", managementConnString, overview.ManagementState.URL)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
managementConnString = "Disconnected"
|
|
||||||
if overview.ManagementState.Error != "" {
|
|
||||||
managementConnString = fmt.Sprintf("%s, reason: %s", managementConnString, overview.ManagementState.Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
var signalConnString string
|
|
||||||
if overview.SignalState.Connected {
|
|
||||||
signalConnString = "Connected"
|
|
||||||
if showURL {
|
|
||||||
signalConnString = fmt.Sprintf("%s to %s", signalConnString, overview.SignalState.URL)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
signalConnString = "Disconnected"
|
|
||||||
if overview.SignalState.Error != "" {
|
|
||||||
signalConnString = fmt.Sprintf("%s, reason: %s", signalConnString, overview.SignalState.Error)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
interfaceTypeString := "Userspace"
|
|
||||||
interfaceIP := overview.IP
|
|
||||||
if overview.KernelInterface {
|
|
||||||
interfaceTypeString = "Kernel"
|
|
||||||
} else if overview.IP == "" {
|
|
||||||
interfaceTypeString = "N/A"
|
|
||||||
interfaceIP = "N/A"
|
|
||||||
}
|
|
||||||
|
|
||||||
var relaysString string
|
|
||||||
if showRelays {
|
|
||||||
for _, relay := range overview.Relays.Details {
|
|
||||||
available := "Available"
|
|
||||||
reason := ""
|
|
||||||
if !relay.Available {
|
|
||||||
available = "Unavailable"
|
|
||||||
reason = fmt.Sprintf(", reason: %s", relay.Error)
|
|
||||||
}
|
|
||||||
relaysString += fmt.Sprintf("\n [%s] is %s%s", relay.URI, available, reason)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
relaysString = fmt.Sprintf("%d/%d Available", overview.Relays.Available, overview.Relays.Total)
|
|
||||||
}
|
|
||||||
|
|
||||||
routes := "-"
|
|
||||||
if len(overview.Routes) > 0 {
|
|
||||||
sort.Strings(overview.Routes)
|
|
||||||
routes = strings.Join(overview.Routes, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
var dnsServersString string
|
|
||||||
if showNameServers {
|
|
||||||
for _, nsServerGroup := range overview.NSServerGroups {
|
|
||||||
enabled := "Available"
|
|
||||||
if !nsServerGroup.Enabled {
|
|
||||||
enabled = "Unavailable"
|
|
||||||
}
|
|
||||||
errorString := ""
|
|
||||||
if nsServerGroup.Error != "" {
|
|
||||||
errorString = fmt.Sprintf(", reason: %s", nsServerGroup.Error)
|
|
||||||
errorString = strings.TrimSpace(errorString)
|
|
||||||
}
|
|
||||||
|
|
||||||
domainsString := strings.Join(nsServerGroup.Domains, ", ")
|
|
||||||
if domainsString == "" {
|
|
||||||
domainsString = "." // Show "." for the default zone
|
|
||||||
}
|
|
||||||
dnsServersString += fmt.Sprintf(
|
|
||||||
"\n [%s] for [%s] is %s%s",
|
|
||||||
strings.Join(nsServerGroup.Servers, ", "),
|
|
||||||
domainsString,
|
|
||||||
enabled,
|
|
||||||
errorString,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
dnsServersString = fmt.Sprintf("%d/%d Available", countEnabled(overview.NSServerGroups), len(overview.NSServerGroups))
|
|
||||||
}
|
|
||||||
|
|
||||||
rosenpassEnabledStatus := "false"
|
|
||||||
if overview.RosenpassEnabled {
|
|
||||||
rosenpassEnabledStatus = "true"
|
|
||||||
if overview.RosenpassPermissive {
|
|
||||||
rosenpassEnabledStatus = "true (permissive)" //nolint:gosec
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
peersCountString := fmt.Sprintf("%d/%d Connected", overview.Peers.Connected, overview.Peers.Total)
|
|
||||||
|
|
||||||
goos := runtime.GOOS
|
|
||||||
goarch := runtime.GOARCH
|
|
||||||
goarm := ""
|
|
||||||
if goarch == "arm" {
|
|
||||||
goarm = fmt.Sprintf(" (ARMv%s)", os.Getenv("GOARM"))
|
|
||||||
}
|
|
||||||
|
|
||||||
summary := fmt.Sprintf(
|
|
||||||
"OS: %s\n"+
|
|
||||||
"Daemon version: %s\n"+
|
|
||||||
"CLI version: %s\n"+
|
|
||||||
"Management: %s\n"+
|
|
||||||
"Signal: %s\n"+
|
|
||||||
"Relays: %s\n"+
|
|
||||||
"Nameservers: %s\n"+
|
|
||||||
"FQDN: %s\n"+
|
|
||||||
"NetBird IP: %s\n"+
|
|
||||||
"Interface type: %s\n"+
|
|
||||||
"Quantum resistance: %s\n"+
|
|
||||||
"Routes: %s\n"+
|
|
||||||
"Peers count: %s\n",
|
|
||||||
fmt.Sprintf("%s/%s%s", goos, goarch, goarm),
|
|
||||||
overview.DaemonVersion,
|
|
||||||
version.NetbirdVersion(),
|
|
||||||
managementConnString,
|
|
||||||
signalConnString,
|
|
||||||
relaysString,
|
|
||||||
dnsServersString,
|
|
||||||
overview.FQDN,
|
|
||||||
interfaceIP,
|
|
||||||
interfaceTypeString,
|
|
||||||
rosenpassEnabledStatus,
|
|
||||||
routes,
|
|
||||||
peersCountString,
|
|
||||||
)
|
|
||||||
return summary
|
|
||||||
}
|
|
||||||
|
|
||||||
func parseToFullDetailSummary(overview statusOutputOverview) string {
|
|
||||||
parsedPeersString := parsePeers(overview.Peers, overview.RosenpassEnabled, overview.RosenpassPermissive)
|
|
||||||
summary := parseGeneralSummary(overview, true, true, true)
|
|
||||||
|
|
||||||
return fmt.Sprintf(
|
|
||||||
"Peers detail:"+
|
|
||||||
"%s\n"+
|
|
||||||
"%s",
|
|
||||||
parsedPeersString,
|
|
||||||
summary,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func parsePeers(peers peersStateOutput, rosenpassEnabled, rosenpassPermissive bool) string {
|
|
||||||
var (
|
|
||||||
peersString = ""
|
|
||||||
)
|
|
||||||
|
|
||||||
for _, peerState := range peers.Details {
|
|
||||||
|
|
||||||
localICE := "-"
|
|
||||||
if peerState.IceCandidateType.Local != "" {
|
|
||||||
localICE = peerState.IceCandidateType.Local
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteICE := "-"
|
|
||||||
if peerState.IceCandidateType.Remote != "" {
|
|
||||||
remoteICE = peerState.IceCandidateType.Remote
|
|
||||||
}
|
|
||||||
|
|
||||||
localICEEndpoint := "-"
|
|
||||||
if peerState.IceCandidateEndpoint.Local != "" {
|
|
||||||
localICEEndpoint = peerState.IceCandidateEndpoint.Local
|
|
||||||
}
|
|
||||||
|
|
||||||
remoteICEEndpoint := "-"
|
|
||||||
if peerState.IceCandidateEndpoint.Remote != "" {
|
|
||||||
remoteICEEndpoint = peerState.IceCandidateEndpoint.Remote
|
|
||||||
}
|
|
||||||
|
|
||||||
rosenpassEnabledStatus := "false"
|
|
||||||
if rosenpassEnabled {
|
|
||||||
if peerState.RosenpassEnabled {
|
|
||||||
rosenpassEnabledStatus = "true"
|
|
||||||
} else {
|
|
||||||
if rosenpassPermissive {
|
|
||||||
rosenpassEnabledStatus = "false (remote didn't enable quantum resistance)"
|
|
||||||
} else {
|
|
||||||
rosenpassEnabledStatus = "false (connection won't work without a permissive mode)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
if peerState.RosenpassEnabled {
|
|
||||||
rosenpassEnabledStatus = "false (connection might not work without a remote permissive mode)"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
routes := "-"
|
|
||||||
if len(peerState.Routes) > 0 {
|
|
||||||
sort.Strings(peerState.Routes)
|
|
||||||
routes = strings.Join(peerState.Routes, ", ")
|
|
||||||
}
|
|
||||||
|
|
||||||
peerString := fmt.Sprintf(
|
|
||||||
"\n %s:\n"+
|
|
||||||
" NetBird IP: %s\n"+
|
|
||||||
" Public key: %s\n"+
|
|
||||||
" Status: %s\n"+
|
|
||||||
" -- detail --\n"+
|
|
||||||
" Connection type: %s\n"+
|
|
||||||
" ICE candidate (Local/Remote): %s/%s\n"+
|
|
||||||
" ICE candidate endpoints (Local/Remote): %s/%s\n"+
|
|
||||||
" Relay server address: %s\n"+
|
|
||||||
" Last connection update: %s\n"+
|
|
||||||
" Last WireGuard handshake: %s\n"+
|
|
||||||
" Transfer status (received/sent) %s/%s\n"+
|
|
||||||
" Quantum resistance: %s\n"+
|
|
||||||
" Routes: %s\n"+
|
|
||||||
" Latency: %s\n",
|
|
||||||
peerState.FQDN,
|
|
||||||
peerState.IP,
|
|
||||||
peerState.PubKey,
|
|
||||||
peerState.Status,
|
|
||||||
peerState.ConnType,
|
|
||||||
localICE,
|
|
||||||
remoteICE,
|
|
||||||
localICEEndpoint,
|
|
||||||
remoteICEEndpoint,
|
|
||||||
peerState.RelayAddress,
|
|
||||||
timeAgo(peerState.LastStatusUpdate),
|
|
||||||
timeAgo(peerState.LastWireguardHandshake),
|
|
||||||
toIEC(peerState.TransferReceived),
|
|
||||||
toIEC(peerState.TransferSent),
|
|
||||||
rosenpassEnabledStatus,
|
|
||||||
routes,
|
|
||||||
peerState.Latency.String(),
|
|
||||||
)
|
|
||||||
|
|
||||||
peersString += peerString
|
|
||||||
}
|
|
||||||
return peersString
|
|
||||||
}
|
|
||||||
|
|
||||||
func skipDetailByFilters(peerState *proto.PeerState, isConnected bool) bool {
|
|
||||||
statusEval := false
|
|
||||||
ipEval := false
|
|
||||||
nameEval := true
|
|
||||||
|
|
||||||
if statusFilter != "" {
|
|
||||||
lowerStatusFilter := strings.ToLower(statusFilter)
|
|
||||||
if lowerStatusFilter == "disconnected" && isConnected {
|
|
||||||
statusEval = true
|
|
||||||
} else if lowerStatusFilter == "connected" && !isConnected {
|
|
||||||
statusEval = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(ipsFilter) > 0 {
|
|
||||||
_, ok := ipsFilterMap[peerState.IP]
|
|
||||||
if !ok {
|
|
||||||
ipEval = true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if len(prefixNamesFilter) > 0 {
|
|
||||||
for prefixNameFilter := range prefixNamesFilterMap {
|
|
||||||
if strings.HasPrefix(peerState.Fqdn, prefixNameFilter) {
|
|
||||||
nameEval = false
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
nameEval = false
|
|
||||||
}
|
|
||||||
|
|
||||||
return statusEval || ipEval || nameEval
|
|
||||||
}
|
|
||||||
|
|
||||||
func toIEC(b int64) string {
|
|
||||||
const unit = 1024
|
|
||||||
if b < unit {
|
|
||||||
return fmt.Sprintf("%d B", b)
|
|
||||||
}
|
|
||||||
div, exp := int64(unit), 0
|
|
||||||
for n := b / unit; n >= unit; n /= unit {
|
|
||||||
div *= unit
|
|
||||||
exp++
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%.1f %ciB",
|
|
||||||
float64(b)/float64(div), "KMGTPE"[exp])
|
|
||||||
}
|
|
||||||
|
|
||||||
func countEnabled(dnsServers []nsServerGroupStateOutput) int {
|
|
||||||
count := 0
|
|
||||||
for _, server := range dnsServers {
|
|
||||||
if server.Enabled {
|
|
||||||
count++
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return count
|
|
||||||
}
|
|
||||||
|
|
||||||
// timeAgo returns a string representing the duration since the provided time in a human-readable format.
|
|
||||||
func timeAgo(t time.Time) string {
|
|
||||||
if t.IsZero() || t.Equal(time.Unix(0, 0)) {
|
|
||||||
return "-"
|
|
||||||
}
|
|
||||||
duration := time.Since(t)
|
|
||||||
switch {
|
|
||||||
case duration < time.Second:
|
|
||||||
return "Now"
|
|
||||||
case duration < time.Minute:
|
|
||||||
seconds := int(duration.Seconds())
|
|
||||||
if seconds == 1 {
|
|
||||||
return "1 second ago"
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d seconds ago", seconds)
|
|
||||||
case duration < time.Hour:
|
|
||||||
minutes := int(duration.Minutes())
|
|
||||||
seconds := int(duration.Seconds()) % 60
|
|
||||||
if minutes == 1 {
|
|
||||||
if seconds == 1 {
|
|
||||||
return "1 minute, 1 second ago"
|
|
||||||
} else if seconds > 0 {
|
|
||||||
return fmt.Sprintf("1 minute, %d seconds ago", seconds)
|
|
||||||
}
|
|
||||||
return "1 minute ago"
|
|
||||||
}
|
|
||||||
if seconds > 0 {
|
|
||||||
return fmt.Sprintf("%d minutes, %d seconds ago", minutes, seconds)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d minutes ago", minutes)
|
|
||||||
case duration < 24*time.Hour:
|
|
||||||
hours := int(duration.Hours())
|
|
||||||
minutes := int(duration.Minutes()) % 60
|
|
||||||
if hours == 1 {
|
|
||||||
if minutes == 1 {
|
|
||||||
return "1 hour, 1 minute ago"
|
|
||||||
} else if minutes > 0 {
|
|
||||||
return fmt.Sprintf("1 hour, %d minutes ago", minutes)
|
|
||||||
}
|
|
||||||
return "1 hour ago"
|
|
||||||
}
|
|
||||||
if minutes > 0 {
|
|
||||||
return fmt.Sprintf("%d hours, %d minutes ago", hours, minutes)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d hours ago", hours)
|
|
||||||
}
|
|
||||||
|
|
||||||
days := int(duration.Hours()) / 24
|
|
||||||
hours := int(duration.Hours()) % 24
|
|
||||||
if days == 1 {
|
|
||||||
if hours == 1 {
|
|
||||||
return "1 day, 1 hour ago"
|
|
||||||
} else if hours > 0 {
|
|
||||||
return fmt.Sprintf("1 day, %d hours ago", hours)
|
|
||||||
}
|
|
||||||
return "1 day ago"
|
|
||||||
}
|
|
||||||
if hours > 0 {
|
|
||||||
return fmt.Sprintf("%d days, %d hours ago", days, hours)
|
|
||||||
}
|
|
||||||
return fmt.Sprintf("%d days ago", days)
|
|
||||||
}
|
|
||||||
|
|
||||||
func anonymizePeerDetail(a *anonymize.Anonymizer, peer *peerStateDetailOutput) {
|
|
||||||
peer.FQDN = a.AnonymizeDomain(peer.FQDN)
|
|
||||||
if localIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Local); err == nil {
|
|
||||||
peer.IceCandidateEndpoint.Local = fmt.Sprintf("%s:%s", a.AnonymizeIPString(localIP), port)
|
|
||||||
}
|
|
||||||
if remoteIP, port, err := net.SplitHostPort(peer.IceCandidateEndpoint.Remote); err == nil {
|
|
||||||
peer.IceCandidateEndpoint.Remote = fmt.Sprintf("%s:%s", a.AnonymizeIPString(remoteIP), port)
|
|
||||||
}
|
|
||||||
|
|
||||||
peer.RelayAddress = a.AnonymizeURI(peer.RelayAddress)
|
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
|
||||||
peer.Routes[i] = a.AnonymizeIPString(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range peer.Routes {
|
|
||||||
peer.Routes[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func anonymizeOverview(a *anonymize.Anonymizer, overview *statusOutputOverview) {
|
|
||||||
for i, peer := range overview.Peers.Details {
|
|
||||||
peer := peer
|
|
||||||
anonymizePeerDetail(a, &peer)
|
|
||||||
overview.Peers.Details[i] = peer
|
|
||||||
}
|
|
||||||
|
|
||||||
overview.ManagementState.URL = a.AnonymizeURI(overview.ManagementState.URL)
|
|
||||||
overview.ManagementState.Error = a.AnonymizeString(overview.ManagementState.Error)
|
|
||||||
overview.SignalState.URL = a.AnonymizeURI(overview.SignalState.URL)
|
|
||||||
overview.SignalState.Error = a.AnonymizeString(overview.SignalState.Error)
|
|
||||||
|
|
||||||
overview.IP = a.AnonymizeIPString(overview.IP)
|
|
||||||
for i, detail := range overview.Relays.Details {
|
|
||||||
detail.URI = a.AnonymizeURI(detail.URI)
|
|
||||||
detail.Error = a.AnonymizeString(detail.Error)
|
|
||||||
overview.Relays.Details[i] = detail
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, nsGroup := range overview.NSServerGroups {
|
|
||||||
for j, domain := range nsGroup.Domains {
|
|
||||||
overview.NSServerGroups[i].Domains[j] = a.AnonymizeDomain(domain)
|
|
||||||
}
|
|
||||||
for j, ns := range nsGroup.Servers {
|
|
||||||
host, port, err := net.SplitHostPort(ns)
|
|
||||||
if err == nil {
|
|
||||||
overview.NSServerGroups[i].Servers[j] = fmt.Sprintf("%s:%s", a.AnonymizeIPString(host), port)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for i, route := range overview.Routes {
|
|
||||||
overview.Routes[i] = a.AnonymizeRoute(route)
|
|
||||||
}
|
|
||||||
|
|
||||||
overview.FQDN = a.AnonymizeDomain(overview.FQDN)
|
|
||||||
}
|
|
||||||
|
|||||||
@@ -1,575 +1,11 @@
|
|||||||
package cmd
|
package cmd
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
|
||||||
"encoding/json"
|
|
||||||
"fmt"
|
|
||||||
"runtime"
|
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
|
||||||
|
|
||||||
"github.com/stretchr/testify/assert"
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
|
||||||
"google.golang.org/protobuf/types/known/durationpb"
|
|
||||||
"google.golang.org/protobuf/types/known/timestamppb"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
"github.com/netbirdio/netbird/version"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
func init() {
|
|
||||||
loc, err := time.LoadLocation("UTC")
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
time.Local = loc
|
|
||||||
}
|
|
||||||
|
|
||||||
var resp = &proto.StatusResponse{
|
|
||||||
Status: "Connected",
|
|
||||||
FullStatus: &proto.FullStatus{
|
|
||||||
Peers: []*proto.PeerState{
|
|
||||||
{
|
|
||||||
IP: "192.168.178.101",
|
|
||||||
PubKey: "Pubkey1",
|
|
||||||
Fqdn: "peer-1.awesome-domain.com",
|
|
||||||
ConnStatus: "Connected",
|
|
||||||
ConnStatusUpdate: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 1, 0, time.UTC)),
|
|
||||||
Relayed: false,
|
|
||||||
LocalIceCandidateType: "",
|
|
||||||
RemoteIceCandidateType: "",
|
|
||||||
LocalIceCandidateEndpoint: "",
|
|
||||||
RemoteIceCandidateEndpoint: "",
|
|
||||||
LastWireguardHandshake: timestamppb.New(time.Date(2001, time.Month(1), 1, 1, 1, 2, 0, time.UTC)),
|
|
||||||
BytesRx: 200,
|
|
||||||
BytesTx: 100,
|
|
||||||
Routes: []string{
|
|
||||||
"10.1.0.0/24",
|
|
||||||
},
|
|
||||||
Latency: durationpb.New(time.Duration(10000000)),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.102",
|
|
||||||
PubKey: "Pubkey2",
|
|
||||||
Fqdn: "peer-2.awesome-domain.com",
|
|
||||||
ConnStatus: "Connected",
|
|
||||||
ConnStatusUpdate: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 2, 0, time.UTC)),
|
|
||||||
Relayed: true,
|
|
||||||
LocalIceCandidateType: "relay",
|
|
||||||
RemoteIceCandidateType: "prflx",
|
|
||||||
LocalIceCandidateEndpoint: "10.0.0.1:10001",
|
|
||||||
RemoteIceCandidateEndpoint: "10.0.10.1:10002",
|
|
||||||
LastWireguardHandshake: timestamppb.New(time.Date(2002, time.Month(2), 2, 2, 2, 3, 0, time.UTC)),
|
|
||||||
BytesRx: 2000,
|
|
||||||
BytesTx: 1000,
|
|
||||||
Latency: durationpb.New(time.Duration(10000000)),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
ManagementState: &proto.ManagementState{
|
|
||||||
URL: "my-awesome-management.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
SignalState: &proto.SignalState{
|
|
||||||
URL: "my-awesome-signal.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
Relays: []*proto.RelayState{
|
|
||||||
{
|
|
||||||
URI: "stun:my-awesome-stun.com:3478",
|
|
||||||
Available: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
|
||||||
Available: false,
|
|
||||||
Error: "context: deadline exceeded",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
LocalPeerState: &proto.LocalPeerState{
|
|
||||||
IP: "192.168.178.100/16",
|
|
||||||
PubKey: "Some-Pub-Key",
|
|
||||||
KernelInterface: true,
|
|
||||||
Fqdn: "some-localhost.awesome-domain.com",
|
|
||||||
Routes: []string{
|
|
||||||
"10.10.0.0/24",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
DnsServers: []*proto.NSGroupState{
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"8.8.8.8:53",
|
|
||||||
},
|
|
||||||
Domains: nil,
|
|
||||||
Enabled: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"1.1.1.1:53",
|
|
||||||
"2.2.2.2:53",
|
|
||||||
},
|
|
||||||
Domains: []string{
|
|
||||||
"example.com",
|
|
||||||
"example.net",
|
|
||||||
},
|
|
||||||
Enabled: false,
|
|
||||||
Error: "timeout",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
DaemonVersion: "0.14.1",
|
|
||||||
}
|
|
||||||
|
|
||||||
var overview = statusOutputOverview{
|
|
||||||
Peers: peersStateOutput{
|
|
||||||
Total: 2,
|
|
||||||
Connected: 2,
|
|
||||||
Details: []peerStateDetailOutput{
|
|
||||||
{
|
|
||||||
IP: "192.168.178.101",
|
|
||||||
PubKey: "Pubkey1",
|
|
||||||
FQDN: "peer-1.awesome-domain.com",
|
|
||||||
Status: "Connected",
|
|
||||||
LastStatusUpdate: time.Date(2001, 1, 1, 1, 1, 1, 0, time.UTC),
|
|
||||||
ConnType: "P2P",
|
|
||||||
IceCandidateType: iceCandidateType{
|
|
||||||
Local: "",
|
|
||||||
Remote: "",
|
|
||||||
},
|
|
||||||
IceCandidateEndpoint: iceCandidateType{
|
|
||||||
Local: "",
|
|
||||||
Remote: "",
|
|
||||||
},
|
|
||||||
LastWireguardHandshake: time.Date(2001, 1, 1, 1, 1, 2, 0, time.UTC),
|
|
||||||
TransferReceived: 200,
|
|
||||||
TransferSent: 100,
|
|
||||||
Routes: []string{
|
|
||||||
"10.1.0.0/24",
|
|
||||||
},
|
|
||||||
Latency: time.Duration(10000000),
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.102",
|
|
||||||
PubKey: "Pubkey2",
|
|
||||||
FQDN: "peer-2.awesome-domain.com",
|
|
||||||
Status: "Connected",
|
|
||||||
LastStatusUpdate: time.Date(2002, 2, 2, 2, 2, 2, 0, time.UTC),
|
|
||||||
ConnType: "Relayed",
|
|
||||||
IceCandidateType: iceCandidateType{
|
|
||||||
Local: "relay",
|
|
||||||
Remote: "prflx",
|
|
||||||
},
|
|
||||||
IceCandidateEndpoint: iceCandidateType{
|
|
||||||
Local: "10.0.0.1:10001",
|
|
||||||
Remote: "10.0.10.1:10002",
|
|
||||||
},
|
|
||||||
LastWireguardHandshake: time.Date(2002, 2, 2, 2, 2, 3, 0, time.UTC),
|
|
||||||
TransferReceived: 2000,
|
|
||||||
TransferSent: 1000,
|
|
||||||
Latency: time.Duration(10000000),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
CliVersion: version.NetbirdVersion(),
|
|
||||||
DaemonVersion: "0.14.1",
|
|
||||||
ManagementState: managementStateOutput{
|
|
||||||
URL: "my-awesome-management.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
SignalState: signalStateOutput{
|
|
||||||
URL: "my-awesome-signal.com:443",
|
|
||||||
Connected: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
Relays: relayStateOutput{
|
|
||||||
Total: 2,
|
|
||||||
Available: 1,
|
|
||||||
Details: []relayStateOutputDetail{
|
|
||||||
{
|
|
||||||
URI: "stun:my-awesome-stun.com:3478",
|
|
||||||
Available: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
URI: "turns:my-awesome-turn.com:443?transport=tcp",
|
|
||||||
Available: false,
|
|
||||||
Error: "context: deadline exceeded",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
},
|
|
||||||
IP: "192.168.178.100/16",
|
|
||||||
PubKey: "Some-Pub-Key",
|
|
||||||
KernelInterface: true,
|
|
||||||
FQDN: "some-localhost.awesome-domain.com",
|
|
||||||
NSServerGroups: []nsServerGroupStateOutput{
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"8.8.8.8:53",
|
|
||||||
},
|
|
||||||
Domains: nil,
|
|
||||||
Enabled: true,
|
|
||||||
Error: "",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
Servers: []string{
|
|
||||||
"1.1.1.1:53",
|
|
||||||
"2.2.2.2:53",
|
|
||||||
},
|
|
||||||
Domains: []string{
|
|
||||||
"example.com",
|
|
||||||
"example.net",
|
|
||||||
},
|
|
||||||
Enabled: false,
|
|
||||||
Error: "timeout",
|
|
||||||
},
|
|
||||||
},
|
|
||||||
Routes: []string{
|
|
||||||
"10.10.0.0/24",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestConversionFromFullStatusToOutputOverview(t *testing.T) {
|
|
||||||
convertedResult := convertToStatusOutputOverview(resp)
|
|
||||||
|
|
||||||
assert.Equal(t, overview, convertedResult)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestSortingOfPeers(t *testing.T) {
|
|
||||||
peers := []peerStateDetailOutput{
|
|
||||||
{
|
|
||||||
IP: "192.168.178.104",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.102",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.101",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.105",
|
|
||||||
},
|
|
||||||
{
|
|
||||||
IP: "192.168.178.103",
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
sortPeersByIP(peers)
|
|
||||||
|
|
||||||
assert.Equal(t, peers[3].IP, "192.168.178.104")
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToJSON(t *testing.T) {
|
|
||||||
jsonString, _ := parseToJSON(overview)
|
|
||||||
|
|
||||||
//@formatter:off
|
|
||||||
expectedJSONString := `
|
|
||||||
{
|
|
||||||
"peers": {
|
|
||||||
"total": 2,
|
|
||||||
"connected": 2,
|
|
||||||
"details": [
|
|
||||||
{
|
|
||||||
"fqdn": "peer-1.awesome-domain.com",
|
|
||||||
"netbirdIp": "192.168.178.101",
|
|
||||||
"publicKey": "Pubkey1",
|
|
||||||
"status": "Connected",
|
|
||||||
"lastStatusUpdate": "2001-01-01T01:01:01Z",
|
|
||||||
"connectionType": "P2P",
|
|
||||||
"iceCandidateType": {
|
|
||||||
"local": "",
|
|
||||||
"remote": ""
|
|
||||||
},
|
|
||||||
"iceCandidateEndpoint": {
|
|
||||||
"local": "",
|
|
||||||
"remote": ""
|
|
||||||
},
|
|
||||||
"relayAddress": "",
|
|
||||||
"lastWireguardHandshake": "2001-01-01T01:01:02Z",
|
|
||||||
"transferReceived": 200,
|
|
||||||
"transferSent": 100,
|
|
||||||
"latency": 10000000,
|
|
||||||
"quantumResistance": false,
|
|
||||||
"routes": [
|
|
||||||
"10.1.0.0/24"
|
|
||||||
]
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"fqdn": "peer-2.awesome-domain.com",
|
|
||||||
"netbirdIp": "192.168.178.102",
|
|
||||||
"publicKey": "Pubkey2",
|
|
||||||
"status": "Connected",
|
|
||||||
"lastStatusUpdate": "2002-02-02T02:02:02Z",
|
|
||||||
"connectionType": "Relayed",
|
|
||||||
"iceCandidateType": {
|
|
||||||
"local": "relay",
|
|
||||||
"remote": "prflx"
|
|
||||||
},
|
|
||||||
"iceCandidateEndpoint": {
|
|
||||||
"local": "10.0.0.1:10001",
|
|
||||||
"remote": "10.0.10.1:10002"
|
|
||||||
},
|
|
||||||
"relayAddress": "",
|
|
||||||
"lastWireguardHandshake": "2002-02-02T02:02:03Z",
|
|
||||||
"transferReceived": 2000,
|
|
||||||
"transferSent": 1000,
|
|
||||||
"latency": 10000000,
|
|
||||||
"quantumResistance": false,
|
|
||||||
"routes": null
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"cliVersion": "development",
|
|
||||||
"daemonVersion": "0.14.1",
|
|
||||||
"management": {
|
|
||||||
"url": "my-awesome-management.com:443",
|
|
||||||
"connected": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
"signal": {
|
|
||||||
"url": "my-awesome-signal.com:443",
|
|
||||||
"connected": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
"relays": {
|
|
||||||
"total": 2,
|
|
||||||
"available": 1,
|
|
||||||
"details": [
|
|
||||||
{
|
|
||||||
"uri": "stun:my-awesome-stun.com:3478",
|
|
||||||
"available": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"uri": "turns:my-awesome-turn.com:443?transport=tcp",
|
|
||||||
"available": false,
|
|
||||||
"error": "context: deadline exceeded"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
},
|
|
||||||
"netbirdIp": "192.168.178.100/16",
|
|
||||||
"publicKey": "Some-Pub-Key",
|
|
||||||
"usesKernelInterface": true,
|
|
||||||
"fqdn": "some-localhost.awesome-domain.com",
|
|
||||||
"quantumResistance": false,
|
|
||||||
"quantumResistancePermissive": false,
|
|
||||||
"routes": [
|
|
||||||
"10.10.0.0/24"
|
|
||||||
],
|
|
||||||
"dnsServers": [
|
|
||||||
{
|
|
||||||
"servers": [
|
|
||||||
"8.8.8.8:53"
|
|
||||||
],
|
|
||||||
"domains": null,
|
|
||||||
"enabled": true,
|
|
||||||
"error": ""
|
|
||||||
},
|
|
||||||
{
|
|
||||||
"servers": [
|
|
||||||
"1.1.1.1:53",
|
|
||||||
"2.2.2.2:53"
|
|
||||||
],
|
|
||||||
"domains": [
|
|
||||||
"example.com",
|
|
||||||
"example.net"
|
|
||||||
],
|
|
||||||
"enabled": false,
|
|
||||||
"error": "timeout"
|
|
||||||
}
|
|
||||||
]
|
|
||||||
}`
|
|
||||||
// @formatter:on
|
|
||||||
|
|
||||||
var expectedJSON bytes.Buffer
|
|
||||||
require.NoError(t, json.Compact(&expectedJSON, []byte(expectedJSONString)))
|
|
||||||
|
|
||||||
assert.Equal(t, expectedJSON.String(), jsonString)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToYAML(t *testing.T) {
|
|
||||||
yaml, _ := parseToYAML(overview)
|
|
||||||
|
|
||||||
expectedYAML :=
|
|
||||||
`peers:
|
|
||||||
total: 2
|
|
||||||
connected: 2
|
|
||||||
details:
|
|
||||||
- fqdn: peer-1.awesome-domain.com
|
|
||||||
netbirdIp: 192.168.178.101
|
|
||||||
publicKey: Pubkey1
|
|
||||||
status: Connected
|
|
||||||
lastStatusUpdate: 2001-01-01T01:01:01Z
|
|
||||||
connectionType: P2P
|
|
||||||
iceCandidateType:
|
|
||||||
local: ""
|
|
||||||
remote: ""
|
|
||||||
iceCandidateEndpoint:
|
|
||||||
local: ""
|
|
||||||
remote: ""
|
|
||||||
relayAddress: ""
|
|
||||||
lastWireguardHandshake: 2001-01-01T01:01:02Z
|
|
||||||
transferReceived: 200
|
|
||||||
transferSent: 100
|
|
||||||
latency: 10ms
|
|
||||||
quantumResistance: false
|
|
||||||
routes:
|
|
||||||
- 10.1.0.0/24
|
|
||||||
- fqdn: peer-2.awesome-domain.com
|
|
||||||
netbirdIp: 192.168.178.102
|
|
||||||
publicKey: Pubkey2
|
|
||||||
status: Connected
|
|
||||||
lastStatusUpdate: 2002-02-02T02:02:02Z
|
|
||||||
connectionType: Relayed
|
|
||||||
iceCandidateType:
|
|
||||||
local: relay
|
|
||||||
remote: prflx
|
|
||||||
iceCandidateEndpoint:
|
|
||||||
local: 10.0.0.1:10001
|
|
||||||
remote: 10.0.10.1:10002
|
|
||||||
relayAddress: ""
|
|
||||||
lastWireguardHandshake: 2002-02-02T02:02:03Z
|
|
||||||
transferReceived: 2000
|
|
||||||
transferSent: 1000
|
|
||||||
latency: 10ms
|
|
||||||
quantumResistance: false
|
|
||||||
routes: []
|
|
||||||
cliVersion: development
|
|
||||||
daemonVersion: 0.14.1
|
|
||||||
management:
|
|
||||||
url: my-awesome-management.com:443
|
|
||||||
connected: true
|
|
||||||
error: ""
|
|
||||||
signal:
|
|
||||||
url: my-awesome-signal.com:443
|
|
||||||
connected: true
|
|
||||||
error: ""
|
|
||||||
relays:
|
|
||||||
total: 2
|
|
||||||
available: 1
|
|
||||||
details:
|
|
||||||
- uri: stun:my-awesome-stun.com:3478
|
|
||||||
available: true
|
|
||||||
error: ""
|
|
||||||
- uri: turns:my-awesome-turn.com:443?transport=tcp
|
|
||||||
available: false
|
|
||||||
error: 'context: deadline exceeded'
|
|
||||||
netbirdIp: 192.168.178.100/16
|
|
||||||
publicKey: Some-Pub-Key
|
|
||||||
usesKernelInterface: true
|
|
||||||
fqdn: some-localhost.awesome-domain.com
|
|
||||||
quantumResistance: false
|
|
||||||
quantumResistancePermissive: false
|
|
||||||
routes:
|
|
||||||
- 10.10.0.0/24
|
|
||||||
dnsServers:
|
|
||||||
- servers:
|
|
||||||
- 8.8.8.8:53
|
|
||||||
domains: []
|
|
||||||
enabled: true
|
|
||||||
error: ""
|
|
||||||
- servers:
|
|
||||||
- 1.1.1.1:53
|
|
||||||
- 2.2.2.2:53
|
|
||||||
domains:
|
|
||||||
- example.com
|
|
||||||
- example.net
|
|
||||||
enabled: false
|
|
||||||
error: timeout
|
|
||||||
`
|
|
||||||
|
|
||||||
assert.Equal(t, expectedYAML, yaml)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToDetail(t *testing.T) {
|
|
||||||
// Calculate time ago based on the fixture dates
|
|
||||||
lastConnectionUpdate1 := timeAgo(overview.Peers.Details[0].LastStatusUpdate)
|
|
||||||
lastHandshake1 := timeAgo(overview.Peers.Details[0].LastWireguardHandshake)
|
|
||||||
lastConnectionUpdate2 := timeAgo(overview.Peers.Details[1].LastStatusUpdate)
|
|
||||||
lastHandshake2 := timeAgo(overview.Peers.Details[1].LastWireguardHandshake)
|
|
||||||
|
|
||||||
detail := parseToFullDetailSummary(overview)
|
|
||||||
|
|
||||||
expectedDetail := fmt.Sprintf(
|
|
||||||
`Peers detail:
|
|
||||||
peer-1.awesome-domain.com:
|
|
||||||
NetBird IP: 192.168.178.101
|
|
||||||
Public key: Pubkey1
|
|
||||||
Status: Connected
|
|
||||||
-- detail --
|
|
||||||
Connection type: P2P
|
|
||||||
ICE candidate (Local/Remote): -/-
|
|
||||||
ICE candidate endpoints (Local/Remote): -/-
|
|
||||||
Relay server address:
|
|
||||||
Last connection update: %s
|
|
||||||
Last WireGuard handshake: %s
|
|
||||||
Transfer status (received/sent) 200 B/100 B
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: 10.1.0.0/24
|
|
||||||
Latency: 10ms
|
|
||||||
|
|
||||||
peer-2.awesome-domain.com:
|
|
||||||
NetBird IP: 192.168.178.102
|
|
||||||
Public key: Pubkey2
|
|
||||||
Status: Connected
|
|
||||||
-- detail --
|
|
||||||
Connection type: Relayed
|
|
||||||
ICE candidate (Local/Remote): relay/prflx
|
|
||||||
ICE candidate endpoints (Local/Remote): 10.0.0.1:10001/10.0.10.1:10002
|
|
||||||
Relay server address:
|
|
||||||
Last connection update: %s
|
|
||||||
Last WireGuard handshake: %s
|
|
||||||
Transfer status (received/sent) 2.0 KiB/1000 B
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: -
|
|
||||||
Latency: 10ms
|
|
||||||
|
|
||||||
OS: %s/%s
|
|
||||||
Daemon version: 0.14.1
|
|
||||||
CLI version: %s
|
|
||||||
Management: Connected to my-awesome-management.com:443
|
|
||||||
Signal: Connected to my-awesome-signal.com:443
|
|
||||||
Relays:
|
|
||||||
[stun:my-awesome-stun.com:3478] is Available
|
|
||||||
[turns:my-awesome-turn.com:443?transport=tcp] is Unavailable, reason: context: deadline exceeded
|
|
||||||
Nameservers:
|
|
||||||
[8.8.8.8:53] for [.] is Available
|
|
||||||
[1.1.1.1:53, 2.2.2.2:53] for [example.com, example.net] is Unavailable, reason: timeout
|
|
||||||
FQDN: some-localhost.awesome-domain.com
|
|
||||||
NetBird IP: 192.168.178.100/16
|
|
||||||
Interface type: Kernel
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: 10.10.0.0/24
|
|
||||||
Peers count: 2/2 Connected
|
|
||||||
`, lastConnectionUpdate1, lastHandshake1, lastConnectionUpdate2, lastHandshake2, runtime.GOOS, runtime.GOARCH, overview.CliVersion)
|
|
||||||
|
|
||||||
assert.Equal(t, expectedDetail, detail)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingToShortVersion(t *testing.T) {
|
|
||||||
shortVersion := parseGeneralSummary(overview, false, false, false)
|
|
||||||
|
|
||||||
expectedString := fmt.Sprintf("OS: %s/%s", runtime.GOOS, runtime.GOARCH) + `
|
|
||||||
Daemon version: 0.14.1
|
|
||||||
CLI version: development
|
|
||||||
Management: Connected
|
|
||||||
Signal: Connected
|
|
||||||
Relays: 1/2 Available
|
|
||||||
Nameservers: 1/2 Available
|
|
||||||
FQDN: some-localhost.awesome-domain.com
|
|
||||||
NetBird IP: 192.168.178.100/16
|
|
||||||
Interface type: Kernel
|
|
||||||
Quantum resistance: false
|
|
||||||
Routes: 10.10.0.0/24
|
|
||||||
Peers count: 2/2 Connected
|
|
||||||
`
|
|
||||||
|
|
||||||
assert.Equal(t, expectedString, shortVersion)
|
|
||||||
}
|
|
||||||
|
|
||||||
func TestParsingOfIP(t *testing.T) {
|
func TestParsingOfIP(t *testing.T) {
|
||||||
InterfaceIP := "192.168.178.123/16"
|
InterfaceIP := "192.168.178.123/16"
|
||||||
|
|
||||||
@@ -577,31 +13,3 @@ func TestParsingOfIP(t *testing.T) {
|
|||||||
|
|
||||||
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
assert.Equal(t, "192.168.178.123\n", parsedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestTimeAgo(t *testing.T) {
|
|
||||||
now := time.Now()
|
|
||||||
|
|
||||||
cases := []struct {
|
|
||||||
name string
|
|
||||||
input time.Time
|
|
||||||
expected string
|
|
||||||
}{
|
|
||||||
{"Now", now, "Now"},
|
|
||||||
{"Seconds ago", now.Add(-10 * time.Second), "10 seconds ago"},
|
|
||||||
{"One minute ago", now.Add(-1 * time.Minute), "1 minute ago"},
|
|
||||||
{"Minutes and seconds ago", now.Add(-(1*time.Minute + 30*time.Second)), "1 minute, 30 seconds ago"},
|
|
||||||
{"One hour ago", now.Add(-1 * time.Hour), "1 hour ago"},
|
|
||||||
{"Hours and minutes ago", now.Add(-(2*time.Hour + 15*time.Minute)), "2 hours, 15 minutes ago"},
|
|
||||||
{"One day ago", now.Add(-24 * time.Hour), "1 day ago"},
|
|
||||||
{"Multiple days ago", now.Add(-(72*time.Hour + 20*time.Minute)), "3 days ago"},
|
|
||||||
{"Zero time", time.Time{}, "-"},
|
|
||||||
{"Unix zero time", time.Unix(0, 0), "-"},
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, tc := range cases {
|
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
|
||||||
result := timeAgo(tc.input)
|
|
||||||
assert.Equal(t, tc.expected, result, "Failed %s", tc.name)
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|||||||
31
client/cmd/system.go
Normal file
31
client/cmd/system.go
Normal file
@@ -0,0 +1,31 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
// Flag constants for system configuration
|
||||||
|
const (
|
||||||
|
disableClientRoutesFlag = "disable-client-routes"
|
||||||
|
disableServerRoutesFlag = "disable-server-routes"
|
||||||
|
disableDNSFlag = "disable-dns"
|
||||||
|
disableFirewallFlag = "disable-firewall"
|
||||||
|
)
|
||||||
|
|
||||||
|
var (
|
||||||
|
disableClientRoutes bool
|
||||||
|
disableServerRoutes bool
|
||||||
|
disableDNS bool
|
||||||
|
disableFirewall bool
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// Add system flags to upCmd
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableClientRoutes, disableClientRoutesFlag, false,
|
||||||
|
"Disable client routes. If enabled, the client won't process client routes received from the management service.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableServerRoutes, disableServerRoutesFlag, false,
|
||||||
|
"Disable server routes. If enabled, the client won't act as a router for server routes received from the management service.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableDNS, disableDNSFlag, false,
|
||||||
|
"Disable DNS. If enabled, the client won't configure DNS settings.")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&disableFirewall, disableFirewallFlag, false,
|
||||||
|
"Disable firewall configuration. If enabled, the client won't modify firewall rules.")
|
||||||
|
}
|
||||||
@@ -10,6 +10,9 @@ import (
|
|||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
@@ -71,7 +74,7 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
s := grpc.NewServer()
|
s := grpc.NewServer()
|
||||||
store, cleanUp, err := mgmt.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
store, cleanUp, err := store.NewTestStoreFromSQL(context.Background(), testFile, t.TempDir())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -87,13 +90,13 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics)
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settings.NewManagerMock())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, peersUpdateManager, secretsManager, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManagerMock(), peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
137
client/cmd/trace.go
Normal file
137
client/cmd/trace.go
Normal file
@@ -0,0 +1,137 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"math/rand"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/spf13/cobra"
|
||||||
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
|
)
|
||||||
|
|
||||||
|
var traceCmd = &cobra.Command{
|
||||||
|
Use: "trace <direction> <source-ip> <dest-ip>",
|
||||||
|
Short: "Trace a packet through the firewall",
|
||||||
|
Example: `
|
||||||
|
netbird debug trace in 192.168.1.10 10.10.0.2 -p tcp --sport 12345 --dport 443 --syn --ack
|
||||||
|
netbird debug trace out 10.10.0.1 8.8.8.8 -p udp --dport 53
|
||||||
|
netbird debug trace in 10.10.0.2 10.10.0.1 -p icmp --type 8 --code 0
|
||||||
|
netbird debug trace in 100.64.1.1 self -p tcp --dport 80`,
|
||||||
|
Args: cobra.ExactArgs(3),
|
||||||
|
RunE: tracePacket,
|
||||||
|
}
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
debugCmd.AddCommand(traceCmd)
|
||||||
|
|
||||||
|
traceCmd.Flags().StringP("protocol", "p", "tcp", "Protocol (tcp/udp/icmp)")
|
||||||
|
traceCmd.Flags().Uint16("sport", 0, "Source port")
|
||||||
|
traceCmd.Flags().Uint16("dport", 0, "Destination port")
|
||||||
|
traceCmd.Flags().Uint8("icmp-type", 0, "ICMP type")
|
||||||
|
traceCmd.Flags().Uint8("icmp-code", 0, "ICMP code")
|
||||||
|
traceCmd.Flags().Bool("syn", false, "TCP SYN flag")
|
||||||
|
traceCmd.Flags().Bool("ack", false, "TCP ACK flag")
|
||||||
|
traceCmd.Flags().Bool("fin", false, "TCP FIN flag")
|
||||||
|
traceCmd.Flags().Bool("rst", false, "TCP RST flag")
|
||||||
|
traceCmd.Flags().Bool("psh", false, "TCP PSH flag")
|
||||||
|
traceCmd.Flags().Bool("urg", false, "TCP URG flag")
|
||||||
|
}
|
||||||
|
|
||||||
|
func tracePacket(cmd *cobra.Command, args []string) error {
|
||||||
|
direction := strings.ToLower(args[0])
|
||||||
|
if direction != "in" && direction != "out" {
|
||||||
|
return fmt.Errorf("invalid direction: use 'in' or 'out'")
|
||||||
|
}
|
||||||
|
|
||||||
|
protocol := cmd.Flag("protocol").Value.String()
|
||||||
|
if protocol != "tcp" && protocol != "udp" && protocol != "icmp" {
|
||||||
|
return fmt.Errorf("invalid protocol: use tcp/udp/icmp")
|
||||||
|
}
|
||||||
|
|
||||||
|
sport, err := cmd.Flags().GetUint16("sport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid source port: %v", err)
|
||||||
|
}
|
||||||
|
dport, err := cmd.Flags().GetUint16("dport")
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("invalid destination port: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For TCP/UDP, generate random ephemeral port (49152-65535) if not specified
|
||||||
|
if protocol != "icmp" {
|
||||||
|
if sport == 0 {
|
||||||
|
sport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
if dport == 0 {
|
||||||
|
dport = uint16(rand.Intn(16383) + 49152)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
var tcpFlags *proto.TCPFlags
|
||||||
|
if protocol == "tcp" {
|
||||||
|
syn, _ := cmd.Flags().GetBool("syn")
|
||||||
|
ack, _ := cmd.Flags().GetBool("ack")
|
||||||
|
fin, _ := cmd.Flags().GetBool("fin")
|
||||||
|
rst, _ := cmd.Flags().GetBool("rst")
|
||||||
|
psh, _ := cmd.Flags().GetBool("psh")
|
||||||
|
urg, _ := cmd.Flags().GetBool("urg")
|
||||||
|
|
||||||
|
tcpFlags = &proto.TCPFlags{
|
||||||
|
Syn: syn,
|
||||||
|
Ack: ack,
|
||||||
|
Fin: fin,
|
||||||
|
Rst: rst,
|
||||||
|
Psh: psh,
|
||||||
|
Urg: urg,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
icmpType, _ := cmd.Flags().GetUint32("icmp-type")
|
||||||
|
icmpCode, _ := cmd.Flags().GetUint32("icmp-code")
|
||||||
|
|
||||||
|
conn, err := getClient(cmd)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
defer conn.Close()
|
||||||
|
|
||||||
|
client := proto.NewDaemonServiceClient(conn)
|
||||||
|
resp, err := client.TracePacket(cmd.Context(), &proto.TracePacketRequest{
|
||||||
|
SourceIp: args[1],
|
||||||
|
DestinationIp: args[2],
|
||||||
|
Protocol: protocol,
|
||||||
|
SourcePort: uint32(sport),
|
||||||
|
DestinationPort: uint32(dport),
|
||||||
|
Direction: direction,
|
||||||
|
TcpFlags: tcpFlags,
|
||||||
|
IcmpType: &icmpType,
|
||||||
|
IcmpCode: &icmpCode,
|
||||||
|
})
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("trace failed: %v", status.Convert(err).Message())
|
||||||
|
}
|
||||||
|
|
||||||
|
printTrace(cmd, args[1], args[2], protocol, sport, dport, resp)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func printTrace(cmd *cobra.Command, src, dst, proto string, sport, dport uint16, resp *proto.TracePacketResponse) {
|
||||||
|
cmd.Printf("Packet trace %s:%d -> %s:%d (%s)\n\n", src, sport, dst, dport, strings.ToUpper(proto))
|
||||||
|
|
||||||
|
for _, stage := range resp.Stages {
|
||||||
|
if stage.ForwardingDetails != nil {
|
||||||
|
cmd.Printf("%s: %s [%s]\n", stage.Name, stage.Message, *stage.ForwardingDetails)
|
||||||
|
} else {
|
||||||
|
cmd.Printf("%s: %s\n", stage.Name, stage.Message)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
disposition := map[bool]string{
|
||||||
|
true: "\033[32mALLOWED\033[0m", // Green
|
||||||
|
false: "\033[31mDENIED\033[0m", // Red
|
||||||
|
}[resp.FinalDisposition]
|
||||||
|
|
||||||
|
cmd.Printf("\nFinal disposition: %s\n", disposition)
|
||||||
|
}
|
||||||
@@ -20,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/system"
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,9 +30,16 @@ const (
|
|||||||
interfaceInputType
|
interfaceInputType
|
||||||
)
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
dnsLabelsFlag = "extra-dns-labels"
|
||||||
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
foregroundMode bool
|
foregroundMode bool
|
||||||
upCmd = &cobra.Command{
|
dnsLabels []string
|
||||||
|
dnsLabelsValidated domain.List
|
||||||
|
|
||||||
|
upCmd = &cobra.Command{
|
||||||
Use: "up",
|
Use: "up",
|
||||||
Short: "install, login and start Netbird client",
|
Short: "install, login and start Netbird client",
|
||||||
RunE: upFunc,
|
RunE: upFunc,
|
||||||
@@ -48,6 +56,15 @@ func init() {
|
|||||||
)
|
)
|
||||||
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
upCmd.PersistentFlags().StringSliceVar(&extraIFaceBlackList, extraIFaceBlackListFlag, nil, "Extra list of default interfaces to ignore for listening")
|
||||||
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
upCmd.PersistentFlags().DurationVar(&dnsRouteInterval, dnsRouteIntervalFlag, time.Minute, "DNS route update interval")
|
||||||
|
upCmd.PersistentFlags().BoolVar(&blockLANAccess, blockLANAccessFlag, false, "Block access to local networks (LAN) when using this peer as a router or exit node")
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().StringSliceVar(&dnsLabels, dnsLabelsFlag, nil,
|
||||||
|
`Sets DNS labels`+
|
||||||
|
`You can specify a comma-separated list of up to 32 labels. `+
|
||||||
|
`An empty string "" clears the previous configuration. `+
|
||||||
|
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||||
|
`or --extra-dns-labels ""`,
|
||||||
|
)
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -66,6 +83,11 @@ func upFunc(cmd *cobra.Command, args []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
dnsLabelsValidated, err = validateDnsLabels(dnsLabels)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
ctx := internal.CtxInitState(cmd.Context())
|
ctx := internal.CtxInitState(cmd.Context())
|
||||||
|
|
||||||
if hostName != "" {
|
if hostName != "" {
|
||||||
@@ -97,6 +119,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
NATExternalIPs: natExternalIPs,
|
NATExternalIPs: natExternalIPs,
|
||||||
CustomDNSAddress: customDNSAddressConverted,
|
CustomDNSAddress: customDNSAddressConverted,
|
||||||
ExtraIFaceBlackList: extraIFaceBlackList,
|
ExtraIFaceBlackList: extraIFaceBlackList,
|
||||||
|
DNSLabels: dnsLabelsValidated,
|
||||||
}
|
}
|
||||||
|
|
||||||
if cmd.Flag(enableRosenpassFlag).Changed {
|
if cmd.Flag(enableRosenpassFlag).Changed {
|
||||||
@@ -147,6 +170,23 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
ic.DNSRouteInterval = &dnsRouteInterval
|
ic.DNSRouteInterval = &dnsRouteInterval
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||||
|
ic.DisableClientRoutes = &disableClientRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||||
|
ic.DisableServerRoutes = &disableServerRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableDNSFlag).Changed {
|
||||||
|
ic.DisableDNS = &disableDNS
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableFirewallFlag).Changed {
|
||||||
|
ic.DisableFirewall = &disableFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
ic.BlockLANAccess = &blockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
providedSetupKey, err := getSetupKey()
|
providedSetupKey, err := getSetupKey()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -172,7 +212,7 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
return connectClient.Run()
|
return connectClient.Run(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
||||||
@@ -222,6 +262,8 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
IsLinuxDesktopClient: isLinuxRunningDesktop(),
|
||||||
Hostname: hostName,
|
Hostname: hostName,
|
||||||
ExtraIFaceBlacklist: extraIFaceBlackList,
|
ExtraIFaceBlacklist: extraIFaceBlackList,
|
||||||
|
DnsLabels: dnsLabels,
|
||||||
|
CleanDNSLabels: dnsLabels != nil && len(dnsLabels) == 0,
|
||||||
}
|
}
|
||||||
|
|
||||||
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
if rootCmd.PersistentFlags().Changed(preSharedKeyFlag) {
|
||||||
@@ -264,6 +306,23 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
loginRequest.DnsRouteInterval = durationpb.New(dnsRouteInterval)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(disableClientRoutesFlag).Changed {
|
||||||
|
loginRequest.DisableClientRoutes = &disableClientRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableServerRoutesFlag).Changed {
|
||||||
|
loginRequest.DisableServerRoutes = &disableServerRoutes
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableDNSFlag).Changed {
|
||||||
|
loginRequest.DisableDns = &disableDNS
|
||||||
|
}
|
||||||
|
if cmd.Flag(disableFirewallFlag).Changed {
|
||||||
|
loginRequest.DisableFirewall = &disableFirewall
|
||||||
|
}
|
||||||
|
|
||||||
|
if cmd.Flag(blockLANAccessFlag).Changed {
|
||||||
|
loginRequest.BlockLanAccess = &blockLANAccess
|
||||||
|
}
|
||||||
|
|
||||||
var loginErr error
|
var loginErr error
|
||||||
|
|
||||||
var loginResp *proto.LoginResponse
|
var loginResp *proto.LoginResponse
|
||||||
@@ -395,6 +454,24 @@ func parseCustomDNSAddress(modified bool) ([]byte, error) {
|
|||||||
return parsed, nil
|
return parsed, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func validateDnsLabels(labels []string) (domain.List, error) {
|
||||||
|
var (
|
||||||
|
domains domain.List
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(labels) == 0 {
|
||||||
|
return domains, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
domains, err = domain.ValidateDomains(labels)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to validate dns labels: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return domains, nil
|
||||||
|
}
|
||||||
|
|
||||||
func isValidAddrPort(input string) bool {
|
func isValidAddrPort(input string) bool {
|
||||||
if input == "" {
|
if input == "" {
|
||||||
return true
|
return true
|
||||||
|
|||||||
24
client/configs/configs.go
Normal file
24
client/configs/configs.go
Normal file
@@ -0,0 +1,24 @@
|
|||||||
|
package configs
|
||||||
|
|
||||||
|
import (
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"runtime"
|
||||||
|
)
|
||||||
|
|
||||||
|
var StateDir string
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
StateDir = os.Getenv("NB_STATE_DIR")
|
||||||
|
if StateDir != "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
switch runtime.GOOS {
|
||||||
|
case "windows":
|
||||||
|
StateDir = filepath.Join(os.Getenv("PROGRAMDATA"), "Netbird")
|
||||||
|
case "darwin", "linux":
|
||||||
|
StateDir = "/var/lib/netbird"
|
||||||
|
case "freebsd", "openbsd", "netbsd", "dragonfly":
|
||||||
|
StateDir = "/var/db/netbird"
|
||||||
|
}
|
||||||
|
}
|
||||||
167
client/embed/doc.go
Normal file
167
client/embed/doc.go
Normal file
@@ -0,0 +1,167 @@
|
|||||||
|
// Package embed provides a way to embed the NetBird client directly
|
||||||
|
// into Go programs without requiring a separate NetBird client installation.
|
||||||
|
package embed
|
||||||
|
|
||||||
|
// Basic Usage:
|
||||||
|
//
|
||||||
|
// client, err := embed.New(embed.Options{
|
||||||
|
// DeviceName: "my-service",
|
||||||
|
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||||
|
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
// if err := client.Start(ctx); err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Complete HTTP Server Example:
|
||||||
|
//
|
||||||
|
// package main
|
||||||
|
//
|
||||||
|
// import (
|
||||||
|
// "context"
|
||||||
|
// "fmt"
|
||||||
|
// "log"
|
||||||
|
// "net/http"
|
||||||
|
// "os"
|
||||||
|
// "os/signal"
|
||||||
|
// "syscall"
|
||||||
|
// "time"
|
||||||
|
//
|
||||||
|
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// func main() {
|
||||||
|
// // Create client with setup key and device name
|
||||||
|
// client, err := netbird.New(netbird.Options{
|
||||||
|
// DeviceName: "http-server",
|
||||||
|
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||||
|
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||||
|
// LogOutput: io.Discard,
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Start with timeout
|
||||||
|
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
// if err := client.Start(ctx); err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Create HTTP server
|
||||||
|
// mux := http.NewServeMux()
|
||||||
|
// mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
|
||||||
|
// fmt.Printf("Request from %s: %s %s\n", r.RemoteAddr, r.Method, r.URL.Path)
|
||||||
|
// fmt.Fprintf(w, "Hello from netbird!")
|
||||||
|
// })
|
||||||
|
//
|
||||||
|
// // Listen on netbird network
|
||||||
|
// l, err := client.ListenTCP(":8080")
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// server := &http.Server{Handler: mux}
|
||||||
|
// go func() {
|
||||||
|
// if err := server.Serve(l); !errors.Is(err, http.ErrServerClosed) {
|
||||||
|
// log.Printf("HTTP server error: %v", err)
|
||||||
|
// }
|
||||||
|
// }()
|
||||||
|
//
|
||||||
|
// log.Printf("HTTP server listening on netbird network port 8080")
|
||||||
|
//
|
||||||
|
// // Handle shutdown
|
||||||
|
// stop := make(chan os.Signal, 1)
|
||||||
|
// signal.Notify(stop, syscall.SIGINT, syscall.SIGTERM)
|
||||||
|
// <-stop
|
||||||
|
//
|
||||||
|
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
//
|
||||||
|
// if err := server.Shutdown(shutdownCtx); err != nil {
|
||||||
|
// log.Printf("HTTP shutdown error: %v", err)
|
||||||
|
// }
|
||||||
|
// if err := client.Stop(shutdownCtx); err != nil {
|
||||||
|
// log.Printf("Netbird shutdown error: %v", err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// Complete HTTP Client Example:
|
||||||
|
//
|
||||||
|
// package main
|
||||||
|
//
|
||||||
|
// import (
|
||||||
|
// "context"
|
||||||
|
// "fmt"
|
||||||
|
// "io"
|
||||||
|
// "log"
|
||||||
|
// "os"
|
||||||
|
// "time"
|
||||||
|
//
|
||||||
|
// netbird "github.com/netbirdio/netbird/client/embed"
|
||||||
|
// )
|
||||||
|
//
|
||||||
|
// func main() {
|
||||||
|
// // Create client with setup key and device name
|
||||||
|
// client, err := netbird.New(netbird.Options{
|
||||||
|
// DeviceName: "http-client",
|
||||||
|
// SetupKey: os.Getenv("NB_SETUP_KEY"),
|
||||||
|
// ManagementURL: os.Getenv("NB_MANAGEMENT_URL"),
|
||||||
|
// LogOutput: io.Discard,
|
||||||
|
// })
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Start with timeout
|
||||||
|
// ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
//
|
||||||
|
// if err := client.Start(ctx); err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// // Create HTTP client that uses netbird network
|
||||||
|
// httpClient := client.NewHTTPClient()
|
||||||
|
// httpClient.Timeout = 10 * time.Second
|
||||||
|
//
|
||||||
|
// // Make request to server in netbird network
|
||||||
|
// target := os.Getenv("NB_TARGET")
|
||||||
|
// resp, err := httpClient.Get(target)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
// defer resp.Body.Close()
|
||||||
|
//
|
||||||
|
// // Read and print response
|
||||||
|
// body, err := io.ReadAll(resp.Body)
|
||||||
|
// if err != nil {
|
||||||
|
// log.Fatal(err)
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// fmt.Printf("Response from server: %s\n", string(body))
|
||||||
|
//
|
||||||
|
// // Clean shutdown
|
||||||
|
// shutdownCtx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||||
|
// defer cancel()
|
||||||
|
//
|
||||||
|
// if err := client.Stop(shutdownCtx); err != nil {
|
||||||
|
// log.Printf("Netbird shutdown error: %v", err)
|
||||||
|
// }
|
||||||
|
// }
|
||||||
|
//
|
||||||
|
// The package provides several methods for network operations:
|
||||||
|
// - Dial: Creates outbound connections
|
||||||
|
// - ListenTCP: Creates TCP listeners
|
||||||
|
// - ListenUDP: Creates UDP listeners
|
||||||
|
//
|
||||||
|
// By default, the embed package uses userspace networking mode, which doesn't
|
||||||
|
// require root/admin privileges. For production deployments, consider setting
|
||||||
|
// appropriate config and state paths for persistence.
|
||||||
293
client/embed/embed.go
Normal file
293
client/embed/embed.go
Normal file
@@ -0,0 +1,293 @@
|
|||||||
|
package embed
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/http"
|
||||||
|
"net/netip"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
wgnetstack "golang.zx2c4.com/wireguard/tun/netstack"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
|
)
|
||||||
|
|
||||||
|
var ErrClientAlreadyStarted = errors.New("client already started")
|
||||||
|
var ErrClientNotStarted = errors.New("client not started")
|
||||||
|
|
||||||
|
// Client manages a netbird embedded client instance
|
||||||
|
type Client struct {
|
||||||
|
deviceName string
|
||||||
|
config *internal.Config
|
||||||
|
mu sync.Mutex
|
||||||
|
cancel context.CancelFunc
|
||||||
|
setupKey string
|
||||||
|
connect *internal.ConnectClient
|
||||||
|
}
|
||||||
|
|
||||||
|
// Options configures a new Client
|
||||||
|
type Options struct {
|
||||||
|
// DeviceName is this peer's name in the network
|
||||||
|
DeviceName string
|
||||||
|
// SetupKey is used for authentication
|
||||||
|
SetupKey string
|
||||||
|
// ManagementURL overrides the default management server URL
|
||||||
|
ManagementURL string
|
||||||
|
// PreSharedKey is the pre-shared key for the WireGuard interface
|
||||||
|
PreSharedKey string
|
||||||
|
// LogOutput is the output destination for logs (defaults to os.Stderr if nil)
|
||||||
|
LogOutput io.Writer
|
||||||
|
// LogLevel sets the logging level (defaults to info if empty)
|
||||||
|
LogLevel string
|
||||||
|
// NoUserspace disables the userspace networking mode. Needs admin/root privileges
|
||||||
|
NoUserspace bool
|
||||||
|
// ConfigPath is the path to the netbird config file. If empty, the config will be stored in memory and not persisted.
|
||||||
|
ConfigPath string
|
||||||
|
// StatePath is the path to the netbird state file
|
||||||
|
StatePath string
|
||||||
|
// DisableClientRoutes disables the client routes
|
||||||
|
DisableClientRoutes bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// New creates a new netbird embedded client
|
||||||
|
func New(opts Options) (*Client, error) {
|
||||||
|
if opts.LogOutput != nil {
|
||||||
|
logrus.SetOutput(opts.LogOutput)
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.LogLevel != "" {
|
||||||
|
level, err := logrus.ParseLevel(opts.LogLevel)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("parse log level: %w", err)
|
||||||
|
}
|
||||||
|
logrus.SetLevel(level)
|
||||||
|
}
|
||||||
|
|
||||||
|
if !opts.NoUserspace {
|
||||||
|
if err := os.Setenv(netstack.EnvUseNetstackMode, "true"); err != nil {
|
||||||
|
return nil, fmt.Errorf("setenv: %w", err)
|
||||||
|
}
|
||||||
|
if err := os.Setenv(netstack.EnvSkipProxy, "true"); err != nil {
|
||||||
|
return nil, fmt.Errorf("setenv: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if opts.StatePath != "" {
|
||||||
|
// TODO: Disable state if path not provided
|
||||||
|
if err := os.Setenv("NB_DNS_STATE_FILE", opts.StatePath); err != nil {
|
||||||
|
return nil, fmt.Errorf("setenv: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
t := true
|
||||||
|
var config *internal.Config
|
||||||
|
var err error
|
||||||
|
input := internal.ConfigInput{
|
||||||
|
ConfigPath: opts.ConfigPath,
|
||||||
|
ManagementURL: opts.ManagementURL,
|
||||||
|
PreSharedKey: &opts.PreSharedKey,
|
||||||
|
DisableServerRoutes: &t,
|
||||||
|
DisableClientRoutes: &opts.DisableClientRoutes,
|
||||||
|
}
|
||||||
|
if opts.ConfigPath != "" {
|
||||||
|
config, err = internal.UpdateOrCreateConfig(input)
|
||||||
|
} else {
|
||||||
|
config, err = internal.CreateInMemoryConfig(input)
|
||||||
|
}
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("create config: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Client{
|
||||||
|
deviceName: opts.DeviceName,
|
||||||
|
setupKey: opts.SetupKey,
|
||||||
|
config: config,
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start begins client operation and blocks until the engine has been started successfully or a startup error occurs.
|
||||||
|
// Pass a context with a deadline to limit the time spent waiting for the engine to start.
|
||||||
|
func (c *Client) Start(startCtx context.Context) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
if c.cancel != nil {
|
||||||
|
return ErrClientAlreadyStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx := internal.CtxInitState(context.Background())
|
||||||
|
// nolint:staticcheck
|
||||||
|
ctx = context.WithValue(ctx, system.DeviceNameCtxKey, c.deviceName)
|
||||||
|
if err := internal.Login(ctx, c.config, c.setupKey, ""); err != nil {
|
||||||
|
return fmt.Errorf("login: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
recorder := peer.NewRecorder(c.config.ManagementURL.String())
|
||||||
|
client := internal.NewConnectClient(ctx, c.config, recorder)
|
||||||
|
|
||||||
|
// either startup error (permanent backoff err) or nil err (successful engine up)
|
||||||
|
// TODO: make after-startup backoff err available
|
||||||
|
run := make(chan struct{}, 1)
|
||||||
|
clientErr := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
if err := client.Run(run); err != nil {
|
||||||
|
clientErr <- err
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-startCtx.Done():
|
||||||
|
if stopErr := client.Stop(); stopErr != nil {
|
||||||
|
return fmt.Errorf("stop error after context done. Stop error: %w. Context done: %w", stopErr, startCtx.Err())
|
||||||
|
}
|
||||||
|
return startCtx.Err()
|
||||||
|
case err := <-clientErr:
|
||||||
|
return fmt.Errorf("startup: %w", err)
|
||||||
|
case <-run:
|
||||||
|
}
|
||||||
|
|
||||||
|
c.connect = client
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully stops the client.
|
||||||
|
// Pass a context with a deadline to limit the time spent waiting for the engine to stop.
|
||||||
|
func (c *Client) Stop(ctx context.Context) error {
|
||||||
|
c.mu.Lock()
|
||||||
|
defer c.mu.Unlock()
|
||||||
|
|
||||||
|
if c.connect == nil {
|
||||||
|
return ErrClientNotStarted
|
||||||
|
}
|
||||||
|
|
||||||
|
done := make(chan error, 1)
|
||||||
|
go func() {
|
||||||
|
done <- c.connect.Stop()
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
c.cancel = nil
|
||||||
|
return ctx.Err()
|
||||||
|
case err := <-done:
|
||||||
|
c.cancel = nil
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("stop: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Dial dials a network address in the netbird network.
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) Dial(ctx context.Context, network, address string) (net.Conn, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
connect := c.connect
|
||||||
|
if connect == nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, ErrClientNotStarted
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
engine := connect.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, errors.New("engine not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
nsnet, err := engine.GetNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get net: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsnet.DialContext(ctx, network, address)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenTCP listens on the given address in the netbird network
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) ListenTCP(address string) (net.Listener, error) {
|
||||||
|
nsnet, addr, err := c.getNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("split host port: %w", err)
|
||||||
|
}
|
||||||
|
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||||
|
|
||||||
|
tcpAddr, err := net.ResolveTCPAddr("tcp", listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve: %w", err)
|
||||||
|
}
|
||||||
|
return nsnet.ListenTCP(tcpAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ListenUDP listens on the given address in the netbird network
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) ListenUDP(address string) (net.PacketConn, error) {
|
||||||
|
nsnet, addr, err := c.getNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
_, port, err := net.SplitHostPort(address)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("split host port: %w", err)
|
||||||
|
}
|
||||||
|
listenAddr := fmt.Sprintf("%s:%s", addr, port)
|
||||||
|
|
||||||
|
udpAddr, err := net.ResolveUDPAddr("udp", listenAddr)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("resolve: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsnet.ListenUDP(udpAddr)
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewHTTPClient returns a configured http.Client that uses the netbird network for requests.
|
||||||
|
// Not applicable if the userspace networking mode is disabled.
|
||||||
|
func (c *Client) NewHTTPClient() *http.Client {
|
||||||
|
transport := &http.Transport{
|
||||||
|
DialContext: c.Dial,
|
||||||
|
}
|
||||||
|
|
||||||
|
return &http.Client{
|
||||||
|
Transport: transport,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Client) getNet() (*wgnetstack.Net, netip.Addr, error) {
|
||||||
|
c.mu.Lock()
|
||||||
|
connect := c.connect
|
||||||
|
if connect == nil {
|
||||||
|
c.mu.Unlock()
|
||||||
|
return nil, netip.Addr{}, errors.New("client not started")
|
||||||
|
}
|
||||||
|
c.mu.Unlock()
|
||||||
|
|
||||||
|
engine := connect.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, netip.Addr{}, errors.New("engine not started")
|
||||||
|
}
|
||||||
|
|
||||||
|
addr, err := engine.Address()
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.Addr{}, fmt.Errorf("engine address: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
nsnet, err := engine.GetNet()
|
||||||
|
if err != nil {
|
||||||
|
return nil, netip.Addr{}, fmt.Errorf("get net: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nsnet, addr, nil
|
||||||
|
}
|
||||||
@@ -10,17 +10,18 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface)
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,12 +34,12 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
// for the userspace packet filtering firewall
|
// for the userspace packet filtering firewall
|
||||||
fm, err := createNativeFirewall(iface, stateManager)
|
fm, err := createNativeFirewall(iface, stateManager, disableServerRoutes)
|
||||||
|
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return fm, err
|
return fm, err
|
||||||
@@ -47,10 +48,10 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewal
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
fm, err := createFW(iface)
|
fm, err := createFW(iface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create firewall: %s", err)
|
return nil, fmt.Errorf("create firewall: %s", err)
|
||||||
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
|||||||
@@ -1,13 +1,18 @@
|
|||||||
package firewall
|
package firewall
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// IFaceMapper defines subset methods of interface required for manager
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
type IFaceMapper interface {
|
type IFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() device.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
SetFilter(device.PacketFilter) error
|
SetFilter(device.PacketFilter) error
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,7 +3,7 @@ package iptables
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"strconv"
|
"slices"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
@@ -19,8 +19,7 @@ const (
|
|||||||
tableName = "filter"
|
tableName = "filter"
|
||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
chainNameInputRules = "NETBIRD-ACL-INPUT"
|
||||||
chainNameOutputRules = "NETBIRD-ACL-OUTPUT"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type aclEntries map[string][][]string
|
type aclEntries map[string][][]string
|
||||||
@@ -31,10 +30,8 @@ type entry struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type aclManager struct {
|
type aclManager struct {
|
||||||
iptablesClient *iptables.IPTables
|
iptablesClient *iptables.IPTables
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
|
||||||
|
|
||||||
entries aclEntries
|
entries aclEntries
|
||||||
optionalEntries map[string][]entry
|
optionalEntries map[string][]entry
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
@@ -42,12 +39,10 @@ type aclManager struct {
|
|||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
}
|
}
|
||||||
|
|
||||||
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper, routingFwChainName string) (*aclManager, error) {
|
func newAclManager(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*aclManager, error) {
|
||||||
m := &aclManager{
|
m := &aclManager{
|
||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
routingFwChainName: routingFwChainName,
|
|
||||||
|
|
||||||
entries: make(map[string][][]string),
|
entries: make(map[string][][]string),
|
||||||
optionalEntries: make(map[string][]entry),
|
optionalEntries: make(map[string][]entry),
|
||||||
ipsetStore: newIpsetStore(),
|
ipsetStore: newIpsetStore(),
|
||||||
@@ -80,32 +75,27 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var dPortVal, sPortVal string
|
chain := chainNameInputRules
|
||||||
if dPort != nil && dPort.Values != nil {
|
|
||||||
// TODO: we support only one port per rule in current implementation of ACLs
|
|
||||||
dPortVal = strconv.Itoa(dPort.Values[0])
|
|
||||||
}
|
|
||||||
if sPort != nil && sPort.Values != nil {
|
|
||||||
sPortVal = strconv.Itoa(sPort.Values[0])
|
|
||||||
}
|
|
||||||
|
|
||||||
var chain string
|
ipsetName = transformIPsetName(ipsetName, sPort, dPort)
|
||||||
if direction == firewall.RuleDirectionOUT {
|
specs := filterRuleSpecs(ip, string(protocol), sPort, dPort, action, ipsetName)
|
||||||
chain = chainNameOutputRules
|
|
||||||
} else {
|
|
||||||
chain = chainNameInputRules
|
|
||||||
}
|
|
||||||
|
|
||||||
ipsetName = transformIPsetName(ipsetName, sPortVal, dPortVal)
|
mangleSpecs := slices.Clone(specs)
|
||||||
specs := filterRuleSpecs(ip, string(protocol), sPortVal, dPortVal, direction, action, ipsetName)
|
mangleSpecs = append(mangleSpecs,
|
||||||
|
"-i", m.wgIface.Name(),
|
||||||
|
"-m", "addrtype", "--dst-type", "LOCAL",
|
||||||
|
"-j", "MARK", "--set-xmark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected),
|
||||||
|
)
|
||||||
|
|
||||||
|
specs = append(specs, "-j", actionToStr(action))
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
if ipList, ipsetExists := m.ipsetStore.ipset(ipsetName); ipsetExists {
|
||||||
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
if err := ipset.Add(ipsetName, ip.String()); err != nil {
|
||||||
@@ -137,7 +127,7 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
m.ipsetStore.addIpList(ipsetName, ipList)
|
m.ipsetStore.addIpList(ipsetName, ipList)
|
||||||
}
|
}
|
||||||
|
|
||||||
ok, err := m.iptablesClient.Exists("filter", chain, specs...)
|
ok, err := m.iptablesClient.Exists(tableFilter, chain, specs...)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("failed to check rule: %w", err)
|
return nil, fmt.Errorf("failed to check rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -145,16 +135,22 @@ func (m *aclManager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("rule already exists")
|
return nil, fmt.Errorf("rule already exists")
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.iptablesClient.Append("filter", chain, specs...); err != nil {
|
if err := m.iptablesClient.Append(tableFilter, chain, specs...); err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := m.iptablesClient.Append(tableMangle, chainRTPRE, mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to add mangle rule: %v", err)
|
||||||
|
mangleSpecs = nil
|
||||||
|
}
|
||||||
|
|
||||||
rule := &Rule{
|
rule := &Rule{
|
||||||
ruleID: uuid.New().String(),
|
ruleID: uuid.New().String(),
|
||||||
specs: specs,
|
specs: specs,
|
||||||
ipsetName: ipsetName,
|
mangleSpecs: mangleSpecs,
|
||||||
ip: ip.String(),
|
ipsetName: ipsetName,
|
||||||
chain: chain,
|
ip: ip.String(),
|
||||||
|
chain: chain,
|
||||||
}
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
@@ -197,6 +193,12 @@ func (m *aclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
return fmt.Errorf("failed to delete rule: %s, %v: %w", r.chain, r.specs, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if r.mangleSpecs != nil {
|
||||||
|
if err := m.iptablesClient.Delete(tableMangle, chainRTPRE, r.mangleSpecs...); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
m.updateState()
|
m.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -214,28 +216,7 @@ func (m *aclManager) Reset() error {
|
|||||||
|
|
||||||
// todo write less destructive cleanup mechanism
|
// todo write less destructive cleanup mechanism
|
||||||
func (m *aclManager) cleanChains() error {
|
func (m *aclManager) cleanChains() error {
|
||||||
ok, err := m.iptablesClient.ChainExists(tableName, chainNameOutputRules)
|
ok, err := m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to list chains: %s", err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
if ok {
|
|
||||||
rules := m.entries["OUTPUT"]
|
|
||||||
for _, rule := range rules {
|
|
||||||
err := m.iptablesClient.DeleteIfExists(tableName, "OUTPUT", rule...)
|
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v, %s", rule, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
err = m.iptablesClient.ClearAndDeleteChain(tableName, chainNameOutputRules)
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to clear and delete %s chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ok, err = m.iptablesClient.ChainExists(tableName, chainNameInputRules)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("failed to list chains: %s", err)
|
log.Debugf("failed to list chains: %s", err)
|
||||||
return err
|
return err
|
||||||
@@ -295,12 +276,6 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// chain netbird-acl-output-rules
|
|
||||||
if err := m.iptablesClient.NewChain(tableName, chainNameOutputRules); err != nil {
|
|
||||||
log.Debugf("failed to create '%s' chain: %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
for chainName, rules := range m.entries {
|
for chainName, rules := range m.entries {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
if err := m.iptablesClient.InsertUnique(tableName, chainName, 1, rule...); err != nil {
|
||||||
@@ -329,40 +304,28 @@ func (m *aclManager) createDefaultChains() error {
|
|||||||
|
|
||||||
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
// The existing FORWARD rules/policies decide outbound traffic towards our interface.
|
||||||
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
// In case the FORWARD policy is set to "drop", we add an established/related rule to allow return traffic for the inbound rule.
|
||||||
|
|
||||||
// The OUTPUT chain gets an extra rule to allow traffic to any set up routes, the return traffic is handled by the INPUT related/established rule.
|
|
||||||
func (m *aclManager) seedInitialEntries() {
|
func (m *aclManager) seedInitialEntries() {
|
||||||
|
|
||||||
established := getConntrackEstablished()
|
established := getConntrackEstablished()
|
||||||
|
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
m.appendToEntries("INPUT", []string{"-i", m.wgIface.Name(), "-j", chainNameInputRules})
|
||||||
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
m.appendToEntries("INPUT", append([]string{"-i", m.wgIface.Name()}, established...))
|
||||||
|
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", "DROP"})
|
// Inbound is handled by our ACLs, the rest is dropped.
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "-j", chainNameOutputRules})
|
// For outbound we respect the FORWARD policy. However, we need to allow established/related traffic for inbound rules.
|
||||||
m.appendToEntries("OUTPUT", []string{"-o", m.wgIface.Name(), "!", "-d", m.wgIface.Address().String(), "-j", "ACCEPT"})
|
|
||||||
m.appendToEntries("OUTPUT", append([]string{"-o", m.wgIface.Name()}, established...))
|
|
||||||
|
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", "DROP"})
|
||||||
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", m.routingFwChainName})
|
|
||||||
m.appendToEntries("FORWARD", append([]string{"-o", m.wgIface.Name()}, established...))
|
m.appendToEntries("FORWARD", []string{"-o", m.wgIface.Name(), "-j", chainRTFWDOUT})
|
||||||
|
m.appendToEntries("FORWARD", []string{"-i", m.wgIface.Name(), "-j", chainRTFWDIN})
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) seedInitialOptionalEntries() {
|
func (m *aclManager) seedInitialOptionalEntries() {
|
||||||
m.optionalEntries["FORWARD"] = []entry{
|
m.optionalEntries["FORWARD"] = []entry{
|
||||||
{
|
{
|
||||||
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", chainNameInputRules},
|
spec: []string{"-m", "mark", "--mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected), "-j", "ACCEPT"},
|
||||||
position: 2,
|
position: 2,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m.optionalEntries["PREROUTING"] = []entry{
|
|
||||||
{
|
|
||||||
spec: []string{"-t", "mangle", "-i", m.wgIface.Name(), "-m", "addrtype", "--dst-type", "LOCAL", "-j", "MARK", "--set-mark", fmt.Sprintf("%#x", nbnet.PreroutingFwmarkRedirected)},
|
|
||||||
position: 1,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
func (m *aclManager) appendToEntries(chainName string, spec []string) {
|
||||||
@@ -396,42 +359,26 @@ func (m *aclManager) updateState() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// filterRuleSpecs returns the specs of a filtering rule
|
// filterRuleSpecs returns the specs of a filtering rule
|
||||||
func filterRuleSpecs(
|
func filterRuleSpecs(ip net.IP, protocol string, sPort, dPort *firewall.Port, action firewall.Action, ipsetName string) (specs []string) {
|
||||||
ip net.IP, protocol string, sPort, dPort string, direction firewall.RuleDirection, action firewall.Action, ipsetName string,
|
|
||||||
) (specs []string) {
|
|
||||||
matchByIP := true
|
matchByIP := true
|
||||||
// don't use IP matching if IP is ip 0.0.0.0
|
// don't use IP matching if IP is ip 0.0.0.0
|
||||||
if ip.String() == "0.0.0.0" {
|
if ip.String() == "0.0.0.0" {
|
||||||
matchByIP = false
|
matchByIP = false
|
||||||
}
|
}
|
||||||
switch direction {
|
|
||||||
case firewall.RuleDirectionIN:
|
if matchByIP {
|
||||||
if matchByIP {
|
if ipsetName != "" {
|
||||||
if ipsetName != "" {
|
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "src")
|
} else {
|
||||||
} else {
|
specs = append(specs, "-s", ip.String())
|
||||||
specs = append(specs, "-s", ip.String())
|
|
||||||
}
|
|
||||||
}
|
|
||||||
case firewall.RuleDirectionOUT:
|
|
||||||
if matchByIP {
|
|
||||||
if ipsetName != "" {
|
|
||||||
specs = append(specs, "-m", "set", "--set", ipsetName, "dst")
|
|
||||||
} else {
|
|
||||||
specs = append(specs, "-d", ip.String())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if protocol != "all" {
|
if protocol != "all" {
|
||||||
specs = append(specs, "-p", protocol)
|
specs = append(specs, "-p", protocol)
|
||||||
}
|
}
|
||||||
if sPort != "" {
|
specs = append(specs, applyPort("--sport", sPort)...)
|
||||||
specs = append(specs, "--sport", sPort)
|
specs = append(specs, applyPort("--dport", dPort)...)
|
||||||
}
|
return specs
|
||||||
if dPort != "" {
|
|
||||||
specs = append(specs, "--dport", dPort)
|
|
||||||
}
|
|
||||||
return append(specs, "-j", actionToStr(action))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func actionToStr(action firewall.Action) string {
|
func actionToStr(action firewall.Action) string {
|
||||||
@@ -441,15 +388,15 @@ func actionToStr(action firewall.Action) string {
|
|||||||
return "DROP"
|
return "DROP"
|
||||||
}
|
}
|
||||||
|
|
||||||
func transformIPsetName(ipsetName string, sPort, dPort string) string {
|
func transformIPsetName(ipsetName string, sPort, dPort *firewall.Port) string {
|
||||||
switch {
|
switch {
|
||||||
case ipsetName == "":
|
case ipsetName == "":
|
||||||
return ""
|
return ""
|
||||||
case sPort != "" && dPort != "":
|
case sPort != nil && dPort != nil:
|
||||||
return ipsetName + "-sport-dport"
|
return ipsetName + "-sport-dport"
|
||||||
case sPort != "":
|
case sPort != nil:
|
||||||
return ipsetName + "-sport"
|
return ipsetName + "-sport"
|
||||||
case dPort != "":
|
case dPort != nil:
|
||||||
return ipsetName + "-dport"
|
return ipsetName + "-dport"
|
||||||
default:
|
default:
|
||||||
return ipsetName
|
return ipsetName
|
||||||
|
|||||||
@@ -13,7 +13,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -31,7 +31,7 @@ type Manager struct {
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -52,7 +52,7 @@ func Create(wgIface iFaceMapper) (*Manager, error) {
|
|||||||
return nil, fmt.Errorf("create router: %w", err)
|
return nil, fmt.Errorf("create router: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.aclMgr, err = newAclManager(iptablesClient, wgIface, chainRTFWD)
|
m.aclMgr, err = newAclManager(iptablesClient, wgIface)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create acl manager: %w", err)
|
return nil, fmt.Errorf("create acl manager: %w", err)
|
||||||
}
|
}
|
||||||
@@ -96,22 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, direction, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -126,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -167,7 +167,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -197,34 +197,52 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
net.ParseIP("0.0.0.0"),
|
nil,
|
||||||
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.RuleDirectionIN,
|
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("failed to allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
}
|
}
|
||||||
_, err = m.AddPeerFiltering(
|
return nil
|
||||||
net.ParseIP("0.0.0.0"),
|
|
||||||
"all",
|
|
||||||
nil,
|
|
||||||
nil,
|
|
||||||
firewall.RuleDirectionOUT,
|
|
||||||
firewall.ActionAccept,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
return err
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Flush doesn't need to be implemented for this manager
|
// Flush doesn't need to be implemented for this manager
|
||||||
func (m *Manager) Flush() error { return nil }
|
func (m *Manager) Flush() error { return nil }
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,15 +10,15 @@ import (
|
|||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ifaceMock = &iFaceMock{
|
var ifaceMock = &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -31,7 +31,7 @@ var ifaceMock = &iFaceMock{
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
func (i *iFaceMock) Name() string {
|
||||||
@@ -41,7 +41,7 @@ func (i *iFaceMock) Name() string {
|
|||||||
panic("NameFunc is not set")
|
panic("NameFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
func (i *iFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc != nil {
|
if i.AddressFunc != nil {
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -62,33 +62,20 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset(nil)
|
err := manager.Close(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
}
|
|
||||||
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{8043: 8046},
|
IsRange: true,
|
||||||
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -97,15 +84,6 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule1 {
|
|
||||||
err := manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, false, r.(*Rule).specs...)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeletePeerRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
@@ -118,32 +96,29 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []int{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
|
|
||||||
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
ok, err := ipv4Client.ChainExists("filter", chainNameInputRules)
|
||||||
require.NoError(t, err, "failed check chain exists")
|
require.NoError(t, err, "failed check chain exists")
|
||||||
|
|
||||||
if ok {
|
if ok {
|
||||||
require.NoErrorf(t, err, "chain '%v' still exists after Reset", chainNameInputRules)
|
require.NoErrorf(t, err, "chain '%v' still exists after Close", chainNameInputRules)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestIptablesManagerIPSet(t *testing.T) {
|
func TestIptablesManagerIPSet(t *testing.T) {
|
||||||
ipv4Client, err := iptables.NewWithProtocol(iptables.ProtocolIPv4)
|
|
||||||
require.NoError(t, err)
|
|
||||||
|
|
||||||
mock := &iFaceMock{
|
mock := &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -161,39 +136,19 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset(nil)
|
err := manager.Close(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var rule1 []fw.Rule
|
|
||||||
t.Run("add first rule with set", func(t *testing.T) {
|
|
||||||
ip := net.ParseIP("10.20.0.2")
|
|
||||||
port := &fw.Port{Values: []int{8080}}
|
|
||||||
rule1, err = manager.AddPeerFiltering(
|
|
||||||
ip, "tcp", nil, port, fw.RuleDirectionOUT,
|
|
||||||
fw.ActionAccept, "default", "accept HTTP traffic",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
|
||||||
|
|
||||||
for _, r := range rule1 {
|
|
||||||
checkRuleSpecs(t, ipv4Client, chainNameOutputRules, true, r.(*Rule).specs...)
|
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-dport", "ipset name must be set")
|
|
||||||
require.Equal(t, r.(*Rule).ip, "10.20.0.2", "ipset IP must be set")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
var rule2 []fw.Rule
|
var rule2 []fw.Rule
|
||||||
t.Run("add second rule", func(t *testing.T) {
|
t.Run("add second rule", func(t *testing.T) {
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []int{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
ip, "tcp", port, nil, fw.RuleDirectionIN, fw.ActionAccept,
|
|
||||||
"default", "accept HTTPS traffic from ports range",
|
|
||||||
)
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -201,15 +156,6 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
t.Run("delete first rule", func(t *testing.T) {
|
|
||||||
for _, r := range rule1 {
|
|
||||||
err := manager.DeletePeerRule(r)
|
|
||||||
require.NoError(t, err, "failed to delete rule")
|
|
||||||
|
|
||||||
require.NotContains(t, manager.aclMgr.ipsetStore.ipsets, r.(*Rule).ruleID, "rule must be removed form the ruleset index")
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
t.Run("delete second rule", func(t *testing.T) {
|
t.Run("delete second rule", func(t *testing.T) {
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
err := manager.DeletePeerRule(r)
|
err := manager.DeletePeerRule(r)
|
||||||
@@ -220,7 +166,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("reset check", func(t *testing.T) {
|
t.Run("reset check", func(t *testing.T) {
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -238,8 +184,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("10.20.0.1"),
|
IP: net.ParseIP("10.20.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("10.20.0.0"),
|
IP: net.ParseIP("10.20.0.0"),
|
||||||
@@ -258,7 +204,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err := manager.Reset(nil)
|
err := manager.Close(nil)
|
||||||
require.NoError(t, err, "clear the manager state")
|
require.NoError(t, err, "clear the manager state")
|
||||||
|
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -269,12 +215,8 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,8 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@@ -23,22 +24,36 @@ import (
|
|||||||
|
|
||||||
// constants needed to manage and create iptable rules
|
// constants needed to manage and create iptable rules
|
||||||
const (
|
const (
|
||||||
tableFilter = "filter"
|
tableFilter = "filter"
|
||||||
tableNat = "nat"
|
tableNat = "nat"
|
||||||
tableMangle = "mangle"
|
tableMangle = "mangle"
|
||||||
|
|
||||||
chainPOSTROUTING = "POSTROUTING"
|
chainPOSTROUTING = "POSTROUTING"
|
||||||
chainPREROUTING = "PREROUTING"
|
chainPREROUTING = "PREROUTING"
|
||||||
chainRTNAT = "NETBIRD-RT-NAT"
|
chainRTNAT = "NETBIRD-RT-NAT"
|
||||||
chainRTFWD = "NETBIRD-RT-FWD"
|
chainRTFWDIN = "NETBIRD-RT-FWD-IN"
|
||||||
|
chainRTFWDOUT = "NETBIRD-RT-FWD-OUT"
|
||||||
chainRTPRE = "NETBIRD-RT-PRE"
|
chainRTPRE = "NETBIRD-RT-PRE"
|
||||||
|
chainRTRDR = "NETBIRD-RT-RDR"
|
||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
jumpPre = "jump-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNat = "jump-nat"
|
jumpNatPre = "jump-nat-pre"
|
||||||
matchSet = "--match-set"
|
jumpNatPost = "jump-nat-post"
|
||||||
|
matchSet = "--match-set"
|
||||||
|
|
||||||
|
dnatSuffix = "_dnat"
|
||||||
|
snatSuffix = "_snat"
|
||||||
|
fwdSuffix = "_fwd"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type ruleInfo struct {
|
||||||
|
chain string
|
||||||
|
table string
|
||||||
|
rule []string
|
||||||
|
}
|
||||||
|
|
||||||
type routeFilteringRuleParams struct {
|
type routeFilteringRuleParams struct {
|
||||||
Sources []netip.Prefix
|
Sources []netip.Prefix
|
||||||
Destination netip.Prefix
|
Destination netip.Prefix
|
||||||
@@ -62,6 +77,7 @@ type router struct {
|
|||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
|
|
||||||
stateManager *statemanager.Manager
|
stateManager *statemanager.Manager
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router, error) {
|
||||||
@@ -69,6 +85,7 @@ func newRouter(iptablesClient *iptables.IPTables, wgIface iFaceMapper) (*router,
|
|||||||
iptablesClient: iptablesClient,
|
iptablesClient: iptablesClient,
|
||||||
rules: make(map[string][]string),
|
rules: make(map[string][]string),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -104,6 +121,7 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -111,7 +129,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
@@ -135,7 +153,16 @@ func (r *router) AddRouteFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
rule := genRouteFilteringRuleSpec(params)
|
rule := genRouteFilteringRuleSpec(params)
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
var err error
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// after the established rule
|
||||||
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||||
|
} else {
|
||||||
|
err = r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add route rule: %v", err)
|
return nil, fmt.Errorf("add route rule: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -147,12 +174,12 @@ func (r *router) AddRouteFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
ruleKey := rule.GetRuleID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
setName := r.findSetNameInRule(rule)
|
setName := r.findSetNameInRule(rule)
|
||||||
|
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("delete route rule: %v", err)
|
return fmt.Errorf("delete route rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
@@ -203,6 +230,10 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
|
|
||||||
// AddNatRule inserts an iptables rule pair into the nat chain
|
// AddNatRule inserts an iptables rule pair into the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if r.legacyManagement {
|
if r.legacyManagement {
|
||||||
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
log.Warnf("This peer is connected to a NetBird Management service with an older version. Allowing all traffic for %s", pair.Destination)
|
||||||
if err := r.addLegacyRouteRule(pair); err != nil {
|
if err := r.addLegacyRouteRule(pair); err != nil {
|
||||||
@@ -229,6 +260,10 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
// RemoveNatRule removes an iptables rule pair from forwarding and nat chains
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -255,7 +290,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
rule := []string{"-s", pair.Source.String(), "-d", pair.Destination.String(), "-j", routingFinalForwardJump}
|
||||||
if err := r.iptablesClient.Append(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.Append(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("add legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -268,7 +303,7 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
@@ -296,7 +331,7 @@ func (r *router) RemoveAllLegacyRouteRules() error {
|
|||||||
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
if !strings.HasPrefix(k, firewall.ForwardingFormatPrefix) {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWD, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
merr = multierror.Append(merr, fmt.Errorf("remove legacy forwarding rule: %v", err))
|
||||||
} else {
|
} else {
|
||||||
delete(r.rules, k)
|
delete(r.rules, k)
|
||||||
@@ -334,9 +369,11 @@ func (r *router) cleanUpDefaultForwardRules() error {
|
|||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
}{
|
}{
|
||||||
{chainRTFWD, tableFilter},
|
{chainRTFWDIN, tableFilter},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTFWDOUT, tableFilter},
|
||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRDR, tableNat},
|
||||||
} {
|
} {
|
||||||
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
ok, err := r.iptablesClient.ChainExists(chainInfo.table, chainInfo.chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -356,16 +393,22 @@ func (r *router) createContainers() error {
|
|||||||
chain string
|
chain string
|
||||||
table string
|
table string
|
||||||
}{
|
}{
|
||||||
{chainRTFWD, tableFilter},
|
{chainRTFWDIN, tableFilter},
|
||||||
|
{chainRTFWDOUT, tableFilter},
|
||||||
{chainRTPRE, tableMangle},
|
{chainRTPRE, tableMangle},
|
||||||
{chainRTNAT, tableNat},
|
{chainRTNAT, tableNat},
|
||||||
|
{chainRTRDR, tableNat},
|
||||||
} {
|
} {
|
||||||
if err := r.createAndSetupChain(chainInfo.chain); err != nil {
|
if err := r.iptablesClient.NewChain(chainInfo.table, chainInfo.chain); err != nil {
|
||||||
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
return fmt.Errorf("create chain %s in table %s: %w", chainInfo.chain, chainInfo.table, err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.insertEstablishedRule(chainRTFWD); err != nil {
|
if err := r.insertEstablishedRule(chainRTFWDIN); err != nil {
|
||||||
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.insertEstablishedRule(chainRTFWDOUT); err != nil {
|
||||||
return fmt.Errorf("insert established rule: %w", err)
|
return fmt.Errorf("insert established rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -406,27 +449,6 @@ func (r *router) addPostroutingRules() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createAndSetupChain(chain string) error {
|
|
||||||
table := r.getTableForChain(chain)
|
|
||||||
|
|
||||||
if err := r.iptablesClient.NewChain(table, chain); err != nil {
|
|
||||||
return fmt.Errorf("failed creating chain %s, error: %v", chain, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) getTableForChain(chain string) string {
|
|
||||||
switch chain {
|
|
||||||
case chainRTNAT:
|
|
||||||
return tableNat
|
|
||||||
case chainRTPRE:
|
|
||||||
return tableMangle
|
|
||||||
default:
|
|
||||||
return tableFilter
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *router) insertEstablishedRule(chain string) error {
|
func (r *router) insertEstablishedRule(chain string) error {
|
||||||
establishedRule := getConntrackEstablished()
|
establishedRule := getConntrackEstablished()
|
||||||
|
|
||||||
@@ -445,28 +467,43 @@ func (r *router) addJumpRules() error {
|
|||||||
// Jump to NAT chain
|
// Jump to NAT chain
|
||||||
natRule := []string{"-j", chainRTNAT}
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return fmt.Errorf("add nat jump rule: %v", err)
|
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[jumpNat] = natRule
|
r.rules[jumpNatPost] = natRule
|
||||||
|
|
||||||
// Jump to prerouting chain
|
// Jump to mangle prerouting chain
|
||||||
preRule := []string{"-j", chainRTPRE}
|
preRule := []string{"-j", chainRTPRE}
|
||||||
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableMangle, chainPREROUTING, 1, preRule...); err != nil {
|
||||||
return fmt.Errorf("add prerouting jump rule: %v", err)
|
return fmt.Errorf("add mangle prerouting jump rule: %v", err)
|
||||||
}
|
}
|
||||||
r.rules[jumpPre] = preRule
|
r.rules[jumpManglePre] = preRule
|
||||||
|
|
||||||
|
// Jump to nat prerouting chain
|
||||||
|
rdrRule := []string{"-j", chainRTRDR}
|
||||||
|
if err := r.iptablesClient.Insert(tableNat, chainPREROUTING, 1, rdrRule...); err != nil {
|
||||||
|
return fmt.Errorf("add nat prerouting jump rule: %v", err)
|
||||||
|
}
|
||||||
|
r.rules[jumpNatPre] = rdrRule
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) cleanJumpRules() error {
|
func (r *router) cleanJumpRules() error {
|
||||||
for _, ruleKey := range []string{jumpNat, jumpPre} {
|
for _, ruleKey := range []string{jumpNatPost, jumpManglePre, jumpNatPre} {
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
table := tableNat
|
var table, chain string
|
||||||
chain := chainPOSTROUTING
|
switch ruleKey {
|
||||||
if ruleKey == jumpPre {
|
case jumpNatPost:
|
||||||
|
table = tableNat
|
||||||
|
chain = chainPOSTROUTING
|
||||||
|
case jumpManglePre:
|
||||||
table = tableMangle
|
table = tableMangle
|
||||||
chain = chainPREROUTING
|
chain = chainPREROUTING
|
||||||
|
case jumpNatPre:
|
||||||
|
table = tableNat
|
||||||
|
chain = chainPREROUTING
|
||||||
|
default:
|
||||||
|
return fmt.Errorf("unknown jump rule: %s", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
if err := r.iptablesClient.DeleteIfExists(table, chain, rule...); err != nil {
|
||||||
@@ -511,6 +548,8 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = rule
|
r.rules[ruleKey] = rule
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -526,6 +565,7 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
log.Debugf("marking rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -555,6 +595,137 @@ func (r *router) updateState() {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
toDestination := rule.TranslatedAddress.String()
|
||||||
|
switch {
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
// no translated port, use original port
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
toDestination += fmt.Sprintf(":%d", rule.TranslatedPort.Values[0])
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
// need the "/originalport" suffix to avoid dnat port randomization
|
||||||
|
toDestination += fmt.Sprintf(":%d-%d/%d", rule.TranslatedPort.Values[0], rule.TranslatedPort.Values[1], rule.DestinationPort.Values[0])
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
|
||||||
|
proto := strings.ToLower(string(rule.Protocol))
|
||||||
|
|
||||||
|
rules := make(map[string]ruleInfo, 3)
|
||||||
|
|
||||||
|
// DNAT rule
|
||||||
|
dnatRule := []string{
|
||||||
|
"!", "-i", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-j", "DNAT",
|
||||||
|
"--to-destination", toDestination,
|
||||||
|
}
|
||||||
|
dnatRule = append(dnatRule, applyPort("--dport", &rule.DestinationPort)...)
|
||||||
|
rules[ruleKey+dnatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTRDR,
|
||||||
|
rule: dnatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// SNAT rule
|
||||||
|
snatRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "MASQUERADE",
|
||||||
|
}
|
||||||
|
snatRule = append(snatRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleKey+snatSuffix] = ruleInfo{
|
||||||
|
table: tableNat,
|
||||||
|
chain: chainRTNAT,
|
||||||
|
rule: snatRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Forward filtering rule, if fwd policy is DROP
|
||||||
|
forwardRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-p", proto,
|
||||||
|
"-d", rule.TranslatedAddress.String(),
|
||||||
|
"-j", "ACCEPT",
|
||||||
|
}
|
||||||
|
forwardRule = append(forwardRule, applyPort("--dport", &rule.TranslatedPort)...)
|
||||||
|
rules[ruleKey+fwdSuffix] = ruleInfo{
|
||||||
|
table: tableFilter,
|
||||||
|
chain: chainRTFWDOUT,
|
||||||
|
rule: forwardRule,
|
||||||
|
}
|
||||||
|
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.Append(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
if rollbackErr := r.rollbackRules(rules); rollbackErr != nil {
|
||||||
|
log.Errorf("rollback failed: %v", rollbackErr)
|
||||||
|
}
|
||||||
|
return nil, fmt.Errorf("add rule %s: %w", key, err)
|
||||||
|
}
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) rollbackRules(rules map[string]ruleInfo) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for key, ruleInfo := range rules {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(ruleInfo.table, ruleInfo.chain, ruleInfo.rule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("rollback rule %s: %w", key, err))
|
||||||
|
// On rollback error, add to rules map for next cleanup
|
||||||
|
r.rules[key] = ruleInfo.rule
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr != nil {
|
||||||
|
r.updateState()
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTRDR, dnatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete DNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if snatRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableNat, chainRTNAT, snatRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete SNAT rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwdRule, exists := r.rules[ruleKey+fwdSuffix]; exists {
|
||||||
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, fwdRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete forward rule: %w", err))
|
||||||
|
}
|
||||||
|
delete(r.rules, ruleKey+fwdSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.updateState()
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
||||||
var rule []string
|
var rule []string
|
||||||
|
|
||||||
@@ -590,10 +761,10 @@ func applyPort(flag string, port *firewall.Port) []string {
|
|||||||
if len(port.Values) > 1 {
|
if len(port.Values) > 1 {
|
||||||
portList := make([]string, len(port.Values))
|
portList := make([]string, len(port.Values))
|
||||||
for i, p := range port.Values {
|
for i, p := range port.Values {
|
||||||
portList[i] = strconv.Itoa(p)
|
portList[i] = strconv.Itoa(int(p))
|
||||||
}
|
}
|
||||||
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
return []string{"-m", "multiport", flag, strings.Join(portList, ",")}
|
||||||
}
|
}
|
||||||
|
|
||||||
return []string{flag, strconv.Itoa(port.Values[0])}
|
return []string{flag, strconv.Itoa(int(port.Values[0]))}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -39,12 +39,14 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Now 5 rules:
|
// Now 5 rules:
|
||||||
// 1. established rule in forward chain
|
// 1. established rule forward in
|
||||||
// 2. jump rule to NAT chain
|
// 2. estbalished rule forward out
|
||||||
// 3. jump rule to PRE chain
|
// 3. jump rule to POST nat chain
|
||||||
// 4. static outbound masquerade rule
|
// 4. jump rule to PRE mangle chain
|
||||||
// 5. static return masquerade rule
|
// 5. jump rule to PRE nat chain
|
||||||
require.Len(t, manager.rules, 5, "should have created rules map")
|
// 6. static outbound masquerade rule
|
||||||
|
// 7. static return masquerade rule
|
||||||
|
require.Len(t, manager.rules, 7, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
@@ -239,7 +241,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{80}},
|
dPort: &firewall.Port{Values: []uint16{80}},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -252,7 +254,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
@@ -285,7 +287,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
@@ -297,7 +299,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -307,8 +309,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||||
dPort: &firewall.Port{Values: []int{22}},
|
dPort: &firewall.Port{Values: []uint16{22}},
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -328,18 +330,18 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
// Log the internal rule
|
// Log the internal rule
|
||||||
t.Logf("Internal rule: %v", rule)
|
t.Logf("Internal rule: %v", rule)
|
||||||
|
|
||||||
// Check if the rule exists in iptables
|
// Check if the rule exists in iptables
|
||||||
exists, err := iptablesClient.Exists(tableFilter, chainRTFWD, rule...)
|
exists, err := iptablesClient.Exists(tableFilter, chainRTFWDIN, rule...)
|
||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
|
|||||||
@@ -5,12 +5,13 @@ type Rule struct {
|
|||||||
ruleID string
|
ruleID string
|
||||||
ipsetName string
|
ipsetName string
|
||||||
|
|
||||||
specs []string
|
specs []string
|
||||||
ip string
|
mangleSpecs []string
|
||||||
chain string
|
ip string
|
||||||
|
chain string
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) ID() string {
|
||||||
return r.ruleID
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -37,6 +37,11 @@ func (s *ipList) UnmarshalJSON(data []byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.ips = temp.IPs
|
s.ips = temp.IPs
|
||||||
|
|
||||||
|
if temp.IPs == nil {
|
||||||
|
temp.IPs = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -89,5 +94,10 @@ func (s *ipsetStore) UnmarshalJSON(data []byte) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
s.ipsets = temp.IPSets
|
s.ipsets = temp.IPSets
|
||||||
|
|
||||||
|
if temp.IPSets == nil {
|
||||||
|
temp.IPSets = make(map[string]*ipList)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,21 +4,20 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress iface.WGAddress `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
return i.NameStr
|
return i.NameStr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Address() device.WGAddress {
|
func (i *InterfaceState) Address() wgaddr.Address {
|
||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -62,7 +61,7 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
ipt.aclMgr.ipsetStore = s.ACLIPsetStore
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := ipt.Reset(nil); err != nil {
|
if err := ipt.Close(nil); err != nil {
|
||||||
return fmt.Errorf("reset iptables manager: %w", err)
|
return fmt.Errorf("reset iptables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -26,8 +26,8 @@ const (
|
|||||||
// Each firewall type for different OS can use different type
|
// Each firewall type for different OS can use different type
|
||||||
// of the properties to hold data of the created rule
|
// of the properties to hold data of the created rule
|
||||||
type Rule interface {
|
type Rule interface {
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
GetRuleID() string
|
ID() string
|
||||||
}
|
}
|
||||||
|
|
||||||
// RuleDirection is the traffic direction which a rule is applied
|
// RuleDirection is the traffic direction which a rule is applied
|
||||||
@@ -65,14 +65,13 @@ type Manager interface {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
AddPeerFiltering(
|
AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
direction RuleDirection,
|
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]Rule, error)
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -81,7 +80,15 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto Protocol,
|
||||||
|
sPort *Port,
|
||||||
|
dPort *Port,
|
||||||
|
action Action,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule
|
// DeleteRouteRule deletes a routing rule
|
||||||
DeleteRouteRule(rule Rule) error
|
DeleteRouteRule(rule Rule) error
|
||||||
@@ -95,11 +102,23 @@ type Manager interface {
|
|||||||
// SetLegacyManagement sets the legacy management mode
|
// SetLegacyManagement sets the legacy management mode
|
||||||
SetLegacyManagement(legacy bool) error
|
SetLegacyManagement(legacy bool) error
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close closes the firewall manager
|
||||||
Reset(stateManager *statemanager.Manager) error
|
Close(stateManager *statemanager.Manager) error
|
||||||
|
|
||||||
// Flush the changes to firewall controller
|
// Flush the changes to firewall controller
|
||||||
Flush() error
|
Flush() error
|
||||||
|
|
||||||
|
SetLogLevel(log.Level)
|
||||||
|
|
||||||
|
EnableRouting() error
|
||||||
|
|
||||||
|
DisableRouting() error
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
AddDNATRule(ForwardRule) (Rule, error)
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
DeleteDNATRule(Rule) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
|
|||||||
27
client/firewall/manager/forward_rule.go
Normal file
27
client/firewall/manager/forward_rule.go
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
)
|
||||||
|
|
||||||
|
// ForwardRule todo figure out better place to this to avoid circular imports
|
||||||
|
type ForwardRule struct {
|
||||||
|
Protocol Protocol
|
||||||
|
DestinationPort Port
|
||||||
|
TranslatedAddress netip.Addr
|
||||||
|
TranslatedPort Port
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ForwardRule) ID() string {
|
||||||
|
id := fmt.Sprintf("%s;%s;%s;%s",
|
||||||
|
r.Protocol,
|
||||||
|
r.DestinationPort.String(),
|
||||||
|
r.TranslatedAddress.String(),
|
||||||
|
r.TranslatedPort.String())
|
||||||
|
return id
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r ForwardRule) String() string {
|
||||||
|
return fmt.Sprintf("protocol: %s, destinationPort: %s, translatedAddress: %s, translatedPort: %s", r.Protocol, r.DestinationPort.String(), r.TranslatedAddress.String(), r.TranslatedPort.String())
|
||||||
|
}
|
||||||
@@ -1,36 +1,37 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
"strconv"
|
"strconv"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Protocol is the protocol of the port
|
|
||||||
type Protocol string
|
|
||||||
|
|
||||||
const (
|
|
||||||
// ProtocolTCP is the TCP protocol
|
|
||||||
ProtocolTCP Protocol = "tcp"
|
|
||||||
|
|
||||||
// ProtocolUDP is the UDP protocol
|
|
||||||
ProtocolUDP Protocol = "udp"
|
|
||||||
|
|
||||||
// ProtocolICMP is the ICMP protocol
|
|
||||||
ProtocolICMP Protocol = "icmp"
|
|
||||||
|
|
||||||
// ProtocolALL cover all supported protocols
|
|
||||||
ProtocolALL Protocol = "all"
|
|
||||||
|
|
||||||
// ProtocolUnknown unknown protocol
|
|
||||||
ProtocolUnknown Protocol = "unknown"
|
|
||||||
)
|
|
||||||
|
|
||||||
// Port of the address for firewall rule
|
// Port of the address for firewall rule
|
||||||
|
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||||
type Port struct {
|
type Port struct {
|
||||||
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
// IsRange is true Values contains two values, the first is the start port, the second is the end port
|
||||||
IsRange bool
|
IsRange bool
|
||||||
|
|
||||||
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
// Values contains one value for single port, multiple values for the list of ports, or two values for the range of ports
|
||||||
Values []int
|
Values []uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewPort(ports ...int) (*Port, error) {
|
||||||
|
if len(ports) == 0 {
|
||||||
|
return nil, fmt.Errorf("no port provided")
|
||||||
|
}
|
||||||
|
|
||||||
|
ports16 := make([]uint16, len(ports))
|
||||||
|
for i, port := range ports {
|
||||||
|
if port < 1 || port > 65535 {
|
||||||
|
return nil, fmt.Errorf("invalid port number: %d (must be between 1-65535)", port)
|
||||||
|
}
|
||||||
|
ports16[i] = uint16(port)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &Port{
|
||||||
|
IsRange: len(ports) > 1,
|
||||||
|
Values: ports16,
|
||||||
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// String interface implementation
|
// String interface implementation
|
||||||
@@ -40,7 +41,11 @@ func (p *Port) String() string {
|
|||||||
if ports != "" {
|
if ports != "" {
|
||||||
ports += ","
|
ports += ","
|
||||||
}
|
}
|
||||||
ports += strconv.Itoa(port)
|
ports += strconv.Itoa(int(port))
|
||||||
}
|
}
|
||||||
|
if p.IsRange {
|
||||||
|
ports = "range:" + ports
|
||||||
|
}
|
||||||
|
|
||||||
return ports
|
return ports
|
||||||
}
|
}
|
||||||
|
|||||||
19
client/firewall/manager/protocol.go
Normal file
19
client/firewall/manager/protocol.go
Normal file
@@ -0,0 +1,19 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
// Protocol is the protocol of the port
|
||||||
|
// todo Move Protocol and Port and RouterPair to the Firwall package or a separate package
|
||||||
|
type Protocol string
|
||||||
|
|
||||||
|
const (
|
||||||
|
// ProtocolTCP is the TCP protocol
|
||||||
|
ProtocolTCP Protocol = "tcp"
|
||||||
|
|
||||||
|
// ProtocolUDP is the UDP protocol
|
||||||
|
ProtocolUDP Protocol = "udp"
|
||||||
|
|
||||||
|
// ProtocolICMP is the ICMP protocol
|
||||||
|
ProtocolICMP Protocol = "icmp"
|
||||||
|
|
||||||
|
// ProtocolALL cover all supported protocols
|
||||||
|
ProtocolALL Protocol = "all"
|
||||||
|
)
|
||||||
@@ -2,10 +2,9 @@ package nftables
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"slices"
|
||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"time"
|
"time"
|
||||||
@@ -23,12 +22,10 @@ import (
|
|||||||
const (
|
const (
|
||||||
|
|
||||||
// rules chains contains the effective ACL rules
|
// rules chains contains the effective ACL rules
|
||||||
chainNameInputRules = "netbird-acl-input-rules"
|
chainNameInputRules = "netbird-acl-input-rules"
|
||||||
chainNameOutputRules = "netbird-acl-output-rules"
|
|
||||||
|
|
||||||
// filter chains contains the rules that jump to the rules chains
|
// filter chains contains the rules that jump to the rules chains
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
chainNameOutputFilter = "netbird-acl-output-filter"
|
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
chainNamePrerouting = "netbird-rt-prerouting"
|
chainNamePrerouting = "netbird-rt-prerouting"
|
||||||
|
|
||||||
@@ -47,9 +44,9 @@ type AclManager struct {
|
|||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
routingFwChainName string
|
routingFwChainName string
|
||||||
|
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
chainInputRules *nftables.Chain
|
chainInputRules *nftables.Chain
|
||||||
chainOutputRules *nftables.Chain
|
chainPrerouting *nftables.Chain
|
||||||
|
|
||||||
ipsetStore *ipsetStore
|
ipsetStore *ipsetStore
|
||||||
rules map[string]*Rule
|
rules map[string]*Rule
|
||||||
@@ -87,14 +84,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *AclManager) AddPeerFiltering(
|
func (m *AclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var ipset *nftables.Set
|
var ipset *nftables.Set
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
@@ -106,7 +102,7 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, direction, action, ipset, comment)
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -123,23 +119,32 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if r.nftSet == nil {
|
if r.nftSet == nil {
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(m.rules, r.ID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
ips, ok := m.ipsetStore.ips(r.nftSet.Name)
|
||||||
if !ok {
|
if !ok {
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(m.rules, r.GetRuleID())
|
if r.mangleRule != nil {
|
||||||
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
delete(m.rules, r.ID())
|
||||||
return m.rConn.Flush()
|
return m.rConn.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := ips[r.ip.String()]; ok {
|
if _, ok := ips[r.ip.String()]; ok {
|
||||||
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
err := m.sConn.SetDeleteElements(r.nftSet, []nftables.SetElement{{Key: r.ip.To4()}})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -158,16 +163,20 @@ func (m *AclManager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
err := m.rConn.DelRule(r.nftRule)
|
if err := m.rConn.DelRule(r.nftRule); err != nil {
|
||||||
if err != nil {
|
|
||||||
log.Errorf("failed to delete rule: %v", err)
|
log.Errorf("failed to delete rule: %v", err)
|
||||||
}
|
}
|
||||||
err = m.rConn.Flush()
|
if r.mangleRule != nil {
|
||||||
if err != nil {
|
if err := m.rConn.DelRule(r.mangleRule); err != nil {
|
||||||
|
log.Errorf("failed to delete mangle rule: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
delete(m.rules, r.GetRuleID())
|
delete(m.rules, r.ID())
|
||||||
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
m.ipsetStore.DeleteReferenceFromIpSet(r.nftSet.Name)
|
||||||
|
|
||||||
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
if m.ipsetStore.HasReferenceToSet(r.nftSet.Name) {
|
||||||
@@ -216,38 +225,6 @@ func (m *AclManager) createDefaultAllowRules() error {
|
|||||||
Exprs: expIn,
|
Exprs: expIn,
|
||||||
})
|
})
|
||||||
|
|
||||||
expOut := []expr.Any{
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
// mask
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: []byte{0, 0, 0, 0},
|
|
||||||
Xor: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
// net address
|
|
||||||
&expr.Cmp{
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
|
|
||||||
_ = m.rConn.InsertRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: m.chainOutputRules,
|
|
||||||
Position: 0,
|
|
||||||
Exprs: expOut,
|
|
||||||
})
|
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
@@ -262,25 +239,32 @@ func (m *AclManager) Flush() error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.refreshRuleHandles(m.chainInputRules); err != nil {
|
if err := m.refreshRuleHandles(m.chainInputRules, false); err != nil {
|
||||||
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
log.Errorf("failed to refresh rule handles ipv4 input chain: %v", err)
|
||||||
}
|
}
|
||||||
|
if err := m.refreshRuleHandles(m.chainPrerouting, true); err != nil {
|
||||||
if err := m.refreshRuleHandles(m.chainOutputRules); err != nil {
|
log.Errorf("failed to refresh rule handles prerouting chain: %v", err)
|
||||||
log.Errorf("failed to refresh rule handles IPv4 output chain: %v", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, direction firewall.RuleDirection, action firewall.Action, ipset *nftables.Set, comment string) (*Rule, error) {
|
func (m *AclManager) addIOFiltering(
|
||||||
ruleId := generatePeerRuleId(ip, sPort, dPort, direction, action, ipset)
|
ip net.IP,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
ipset *nftables.Set,
|
||||||
|
) (*Rule, error) {
|
||||||
|
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
return &Rule{
|
return &Rule{
|
||||||
r.nftRule,
|
nftRule: r.nftRule,
|
||||||
r.nftSet,
|
mangleRule: r.mangleRule,
|
||||||
r.ruleID,
|
nftSet: r.nftSet,
|
||||||
ip,
|
ruleID: r.ruleID,
|
||||||
|
ip: ip,
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -312,9 +296,6 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
if !bytes.HasPrefix(anyIP, rawIP) {
|
if !bytes.HasPrefix(anyIP, rawIP) {
|
||||||
// source address position
|
// source address position
|
||||||
addrOffset := uint32(12)
|
addrOffset := uint32(12)
|
||||||
if direction == firewall.RuleDirectionOUT {
|
|
||||||
addrOffset += 4 // is ipv4 address length
|
|
||||||
}
|
|
||||||
|
|
||||||
expressions = append(expressions,
|
expressions = append(expressions,
|
||||||
&expr.Payload{
|
&expr.Payload{
|
||||||
@@ -344,73 +325,100 @@ func (m *AclManager) addIOFiltering(ip net.IP, proto firewall.Protocol, sPort *f
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if sPort != nil && len(sPort.Values) != 0 {
|
expressions = append(expressions, applyPort(sPort, true)...)
|
||||||
expressions = append(expressions,
|
expressions = append(expressions, applyPort(dPort, false)...)
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 0,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*sPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
if dPort != nil && len(dPort.Values) != 0 {
|
mainExpressions := slices.Clone(expressions)
|
||||||
expressions = append(expressions,
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseTransportHeader,
|
|
||||||
Offset: 2,
|
|
||||||
Len: 2,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: encodePort(*dPort),
|
|
||||||
},
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
switch action {
|
switch action {
|
||||||
case firewall.ActionAccept:
|
case firewall.ActionAccept:
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictAccept})
|
||||||
case firewall.ActionDrop:
|
case firewall.ActionDrop:
|
||||||
expressions = append(expressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
userData := []byte(ruleId)
|
||||||
|
|
||||||
var chain *nftables.Chain
|
chain := m.chainInputRules
|
||||||
if direction == firewall.RuleDirectionIN {
|
|
||||||
chain = m.chainInputRules
|
|
||||||
} else {
|
|
||||||
chain = m.chainOutputRules
|
|
||||||
}
|
|
||||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Chain: chain,
|
Chain: chain,
|
||||||
Exprs: expressions,
|
Exprs: mainExpressions,
|
||||||
UserData: userData,
|
UserData: userData,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
if err := m.rConn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
rule := &Rule{
|
rule := &Rule{
|
||||||
nftRule: nftRule,
|
nftRule: nftRule,
|
||||||
nftSet: ipset,
|
mangleRule: m.createPreroutingRule(expressions, userData),
|
||||||
ruleID: ruleId,
|
nftSet: ipset,
|
||||||
ip: ip,
|
ruleID: ruleId,
|
||||||
|
ip: ip,
|
||||||
}
|
}
|
||||||
m.rules[ruleId] = rule
|
m.rules[ruleId] = rule
|
||||||
if ipset != nil {
|
if ipset != nil {
|
||||||
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
m.ipsetStore.AddReferenceToIpset(ipset.Name)
|
||||||
}
|
}
|
||||||
|
|
||||||
return rule, nil
|
return rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (m *AclManager) createPreroutingRule(expressions []expr.Any, userData []byte) *nftables.Rule {
|
||||||
|
if m.chainPrerouting == nil {
|
||||||
|
log.Warn("prerouting chain is not created")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
preroutingExprs := slices.Clone(expressions)
|
||||||
|
|
||||||
|
// interface
|
||||||
|
preroutingExprs = append([]expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(m.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}, preroutingExprs...)
|
||||||
|
|
||||||
|
// local destination and mark
|
||||||
|
preroutingExprs = append(preroutingExprs,
|
||||||
|
&expr.Fib{
|
||||||
|
Register: 1,
|
||||||
|
ResultADDRTYPE: true,
|
||||||
|
FlagDADDR: true,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
||||||
|
},
|
||||||
|
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
|
},
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return m.rConn.AddRule(&nftables.Rule{
|
||||||
|
Table: m.workTable,
|
||||||
|
Chain: m.chainPrerouting,
|
||||||
|
Exprs: preroutingExprs,
|
||||||
|
UserData: userData,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
func (m *AclManager) createDefaultChains() (err error) {
|
func (m *AclManager) createDefaultChains() (err error) {
|
||||||
// chainNameInputRules
|
// chainNameInputRules
|
||||||
chain := m.createChain(chainNameInputRules)
|
chain := m.createChain(chainNameInputRules)
|
||||||
@@ -421,15 +429,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
}
|
}
|
||||||
m.chainInputRules = chain
|
m.chainInputRules = chain
|
||||||
|
|
||||||
// chainNameOutputRules
|
|
||||||
chain = m.createChain(chainNameOutputRules)
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputRules, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
m.chainOutputRules = chain
|
|
||||||
|
|
||||||
// netbird-acl-input-filter
|
// netbird-acl-input-filter
|
||||||
// type filter hook input priority filter; policy accept;
|
// type filter hook input priority filter; policy accept;
|
||||||
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
chain = m.createFilterChainWithHook(chainNameInputFilter, nftables.ChainHookInput)
|
||||||
@@ -441,18 +440,6 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
// netbird-acl-output-filter
|
|
||||||
// type filter hook output priority filter; policy accept;
|
|
||||||
chain = m.createFilterChainWithHook(chainNameOutputFilter, nftables.ChainHookOutput)
|
|
||||||
m.addFwdAllow(chain, expr.MetaKeyOIFNAME)
|
|
||||||
m.addJumpRule(chain, m.chainOutputRules.Name, expr.MetaKeyOIFNAME) // to netbird-acl-output-rules
|
|
||||||
m.addDropExpressions(chain, expr.MetaKeyOIFNAME)
|
|
||||||
err = m.rConn.Flush()
|
|
||||||
if err != nil {
|
|
||||||
log.Debugf("failed to create chain (%s): %s", chainNameOutputFilter, err)
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
// netbird-acl-forward-filter
|
// netbird-acl-forward-filter
|
||||||
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
chainFwFilter := m.createFilterChainWithHook(chainNameForwardFilter, nftables.ChainHookForward)
|
||||||
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
m.addJumpRulesToRtForward(chainFwFilter) // to netbird-rt-fwd
|
||||||
@@ -475,7 +462,7 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
// netbird peer IP.
|
// netbird peer IP.
|
||||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
preroutingChain := m.rConn.AddChain(&nftables.Chain{
|
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
||||||
Name: chainNamePrerouting,
|
Name: chainNamePrerouting,
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
@@ -483,8 +470,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
|||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
})
|
})
|
||||||
|
|
||||||
m.addPreroutingRule(preroutingChain)
|
|
||||||
|
|
||||||
m.addFwmarkToForward(chainFwFilter)
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
if err := m.rConn.Flush(); err != nil {
|
if err := m.rConn.Flush(); err != nil {
|
||||||
@@ -494,43 +479,6 @@ func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addPreroutingRule(preroutingChain *nftables.Chain) {
|
|
||||||
m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: m.workTable,
|
|
||||||
Chain: preroutingChain,
|
|
||||||
Exprs: []expr.Any{
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyIIFNAME,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Fib{
|
|
||||||
Register: 1,
|
|
||||||
ResultADDRTYPE: true,
|
|
||||||
FlagDADDR: true,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(unix.RTN_LOCAL),
|
|
||||||
},
|
|
||||||
&expr.Immediate{
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
|
||||||
},
|
|
||||||
&expr.Meta{
|
|
||||||
Key: expr.MetaKeyMARK,
|
|
||||||
Register: 1,
|
|
||||||
SourceRegister: true,
|
|
||||||
},
|
|
||||||
},
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
||||||
m.rConn.InsertRule(&nftables.Rule{
|
m.rConn.InsertRule(&nftables.Rule{
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
@@ -546,8 +494,7 @@ func (m *AclManager) addFwmarkToForward(chainFwFilter *nftables.Chain) {
|
|||||||
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.PreroutingFwmarkRedirected),
|
||||||
},
|
},
|
||||||
&expr.Verdict{
|
&expr.Verdict{
|
||||||
Kind: expr.VerdictJump,
|
Kind: expr.VerdictAccept,
|
||||||
Chain: m.chainInputRules.Name,
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
})
|
})
|
||||||
@@ -619,45 +566,6 @@ func (m *AclManager) addDropExpressions(chain *nftables.Chain, ifaceKey expr.Met
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) addFwdAllow(chain *nftables.Chain, iifname expr.MetaKey) {
|
|
||||||
ip, _ := netip.AddrFromSlice(m.wgIface.Address().Network.IP.To4())
|
|
||||||
dstOp := expr.CmpOpNeq
|
|
||||||
expressions := []expr.Any{
|
|
||||||
&expr.Meta{Key: iifname, Register: 1},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpEq,
|
|
||||||
Register: 1,
|
|
||||||
Data: ifname(m.wgIface.Name()),
|
|
||||||
},
|
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 2,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 16,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 2,
|
|
||||||
DestRegister: 2,
|
|
||||||
Len: 4,
|
|
||||||
Xor: []byte{0x0, 0x0, 0x0, 0x0},
|
|
||||||
Mask: m.wgIface.Address().Network.Mask,
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: dstOp,
|
|
||||||
Register: 2,
|
|
||||||
Data: ip.Unmap().AsSlice(),
|
|
||||||
},
|
|
||||||
&expr.Verdict{
|
|
||||||
Kind: expr.VerdictAccept,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
_ = m.rConn.AddRule(&nftables.Rule{
|
|
||||||
Table: chain.Table,
|
|
||||||
Chain: chain,
|
|
||||||
Exprs: expressions,
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
func (m *AclManager) addJumpRule(chain *nftables.Chain, to string, ifaceKey expr.MetaKey) {
|
||||||
expressions := []expr.Any{
|
expressions := []expr.Any{
|
||||||
&expr.Meta{Key: ifaceKey, Register: 1},
|
&expr.Meta{Key: ifaceKey, Register: 1},
|
||||||
@@ -733,6 +641,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
|||||||
for i := 0; ; i++ {
|
for i := 0; ; i++ {
|
||||||
err = m.rConn.Flush()
|
err = m.rConn.Flush()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
log.Debugf("failed to flush nftables: %v", err)
|
||||||
if !strings.Contains(err.Error(), "busy") {
|
if !strings.Contains(err.Error(), "busy") {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -749,7 +658,7 @@ func (m *AclManager) flushWithBackoff() (err error) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
func (m *AclManager) refreshRuleHandles(chain *nftables.Chain, mangle bool) error {
|
||||||
if m.workTable == nil || chain == nil {
|
if m.workTable == nil || chain == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -766,22 +675,19 @@ func (m *AclManager) refreshRuleHandles(chain *nftables.Chain) error {
|
|||||||
split := bytes.Split(rule.UserData, []byte(" "))
|
split := bytes.Split(rule.UserData, []byte(" "))
|
||||||
r, ok := m.rules[string(split[0])]
|
r, ok := m.rules[string(split[0])]
|
||||||
if ok {
|
if ok {
|
||||||
*r.nftRule = *rule
|
if mangle {
|
||||||
|
*r.mangleRule = *rule
|
||||||
|
} else {
|
||||||
|
*r.nftRule = *rule
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func generatePeerRuleId(
|
func generatePeerRuleId(ip net.IP, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action, ipset *nftables.Set) string {
|
||||||
ip net.IP,
|
rulesetID := ":"
|
||||||
sPort *firewall.Port,
|
|
||||||
dPort *firewall.Port,
|
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
|
||||||
ipset *nftables.Set,
|
|
||||||
) string {
|
|
||||||
rulesetID := ":" + strconv.Itoa(int(direction)) + ":"
|
|
||||||
if sPort != nil {
|
if sPort != nil {
|
||||||
rulesetID += sPort.String()
|
rulesetID += sPort.String()
|
||||||
}
|
}
|
||||||
@@ -797,12 +703,6 @@ func generatePeerRuleId(
|
|||||||
return "set:" + ipset.Name + rulesetID
|
return "set:" + ipset.Name + rulesetID
|
||||||
}
|
}
|
||||||
|
|
||||||
func encodePort(port firewall.Port) []byte {
|
|
||||||
bs := make([]byte, 2)
|
|
||||||
binary.BigEndian.PutUint16(bs, uint16(port.Values[0]))
|
|
||||||
return bs
|
|
||||||
}
|
|
||||||
|
|
||||||
func ifname(n string) []byte {
|
func ifname(n string) []byte {
|
||||||
b := make([]byte, 16)
|
b := make([]byte, 16)
|
||||||
copy(b, n+"\x00")
|
copy(b, n+"\x00")
|
||||||
|
|||||||
@@ -14,7 +14,7 @@ import (
|
|||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -29,7 +29,7 @@ const (
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMapper interface {
|
type iFaceMapper interface {
|
||||||
Name() string
|
Name() string
|
||||||
Address() iface.WGAddress
|
Address() wgaddr.Address
|
||||||
IsUserspaceBind() bool
|
IsUserspaceBind() bool
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -87,7 +87,7 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// We only need to record minimal interface state for potential recreation.
|
// We only need to record minimal interface state for potential recreation.
|
||||||
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
// Unlike iptables, which requires tracking individual rules, nftables maintains
|
||||||
// a known state (our netbird table plus a few static rules). This allows for easy
|
// a known state (our netbird table plus a few static rules). This allows for easy
|
||||||
// cleanup using Reset() without needing to store specific rules.
|
// cleanup using Close() without needing to store specific rules.
|
||||||
if err := stateManager.UpdateState(&ShutdownState{
|
if err := stateManager.UpdateState(&ShutdownState{
|
||||||
InterfaceState: &InterfaceState{
|
InterfaceState: &InterfaceState{
|
||||||
NameStr: m.wgIface.Name(),
|
NameStr: m.wgIface.Name(),
|
||||||
@@ -113,14 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
direction firewall.RuleDirection,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -130,10 +129,18 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, direction, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Prefix, proto firewall.Protocol, sPort *firewall.Port, dPort *firewall.Port, action firewall.Action) (firewall.Rule, error) {
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto firewall.Protocol,
|
||||||
|
sPort *firewall.Port,
|
||||||
|
dPort *firewall.Port,
|
||||||
|
action firewall.Action,
|
||||||
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -141,7 +148,7 @@ func (m *Manager) AddRouteFiltering(sources []netip.Prefix, destination netip.Pr
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -236,7 +243,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
@@ -312,6 +319,19 @@ func (m *Manager) cleanupNetbirdTables() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLogLevel sets the log level for the firewall manager
|
||||||
|
func (m *Manager) SetLogLevel(log.Level) {
|
||||||
|
// not supported
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) EnableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) DisableRouting() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
// Flush rule/chain/set operations from the buffer
|
// Flush rule/chain/set operations from the buffer
|
||||||
//
|
//
|
||||||
// Method also get all rules after flush and refreshes handle values in the rulesets
|
// Method also get all rules after flush and refreshes handle values in the rulesets
|
||||||
@@ -323,6 +343,22 @@ func (m *Manager) Flush() error {
|
|||||||
return m.aclManager.Flush()
|
return m.aclManager.Flush()
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// AddDNATRule adds a DNAT rule
|
||||||
|
func (m *Manager) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.AddDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
|
func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.DeleteDNATRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package nftables
|
package nftables
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"bytes"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
"os/exec"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -14,15 +16,15 @@ import (
|
|||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
var ifaceMock = &iFaceMock{
|
var ifaceMock = &iFaceMock{
|
||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
@@ -35,7 +37,7 @@ var ifaceMock = &iFaceMock{
|
|||||||
// iFaceMapper defines subset methods of interface required for manager
|
// iFaceMapper defines subset methods of interface required for manager
|
||||||
type iFaceMock struct {
|
type iFaceMock struct {
|
||||||
NameFunc func() string
|
NameFunc func() string
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Name() string {
|
func (i *iFaceMock) Name() string {
|
||||||
@@ -45,7 +47,7 @@ func (i *iFaceMock) Name() string {
|
|||||||
panic("NameFunc is not set")
|
panic("NameFunc is not set")
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *iFaceMock) Address() iface.WGAddress {
|
func (i *iFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc != nil {
|
if i.AddressFunc != nil {
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -63,7 +65,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
}()
|
}()
|
||||||
@@ -72,16 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(
|
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
ip,
|
|
||||||
fw.ProtocolTCP,
|
|
||||||
nil,
|
|
||||||
&fw.Port{Values: []int{53}},
|
|
||||||
fw.RuleDirectionIN,
|
|
||||||
fw.ActionDrop,
|
|
||||||
"",
|
|
||||||
"",
|
|
||||||
)
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -114,7 +107,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
Kind: expr.VerdictAccept,
|
Kind: expr.VerdictAccept,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
require.ElementsMatch(t, rules[0].Exprs, expectedExprs1, "expected the same expressions")
|
compareExprsIgnoringCounters(t, rules[0].Exprs, expectedExprs1)
|
||||||
|
|
||||||
ipToAdd, _ := netip.AddrFromSlice(ip)
|
ipToAdd, _ := netip.AddrFromSlice(ip)
|
||||||
add := ipToAdd.Unmap()
|
add := ipToAdd.Unmap()
|
||||||
@@ -169,7 +162,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
// established rule remains
|
// established rule remains
|
||||||
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
require.Len(t, rules, 1, "expected 1 rules after deletion")
|
||||||
|
|
||||||
err = manager.Reset(nil)
|
err = manager.Close(nil)
|
||||||
require.NoError(t, err, "failed to reset")
|
require.NoError(t, err, "failed to reset")
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -178,8 +171,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
NameFunc: func() string {
|
NameFunc: func() string {
|
||||||
return "lo"
|
return "lo"
|
||||||
},
|
},
|
||||||
AddressFunc: func() iface.WGAddress {
|
AddressFunc: func() wgaddr.Address {
|
||||||
return iface.WGAddress{
|
return wgaddr.Address{
|
||||||
IP: net.ParseIP("100.96.0.1"),
|
IP: net.ParseIP("100.96.0.1"),
|
||||||
Network: &net.IPNet{
|
Network: &net.IPNet{
|
||||||
IP: net.ParseIP("100.96.0.0"),
|
IP: net.ParseIP("100.96.0.0"),
|
||||||
@@ -198,7 +191,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
time.Sleep(time.Second * 3)
|
time.Sleep(time.Second * 3)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(nil); err != nil {
|
if err := manager.Close(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -207,12 +200,8 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -225,3 +214,112 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func runIptablesSave(t *testing.T) (string, string) {
|
||||||
|
t.Helper()
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd := exec.Command("iptables-save")
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
err := cmd.Run()
|
||||||
|
require.NoError(t, err, "iptables-save failed to run")
|
||||||
|
|
||||||
|
return stdout.String(), stderr.String()
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyIptablesOutput(t *testing.T, stdout, stderr string) {
|
||||||
|
t.Helper()
|
||||||
|
// Check for any incompatibility warnings
|
||||||
|
require.NotContains(t,
|
||||||
|
stderr,
|
||||||
|
"incompatible",
|
||||||
|
"iptables-save produced compatibility warning. Full stderr: %s",
|
||||||
|
stderr,
|
||||||
|
)
|
||||||
|
|
||||||
|
// Verify standard tables are present
|
||||||
|
expectedTables := []string{
|
||||||
|
"*filter",
|
||||||
|
"*nat",
|
||||||
|
"*mangle",
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, table := range expectedTables {
|
||||||
|
require.Contains(t,
|
||||||
|
stdout,
|
||||||
|
table,
|
||||||
|
"iptables-save output missing expected table: %s\nFull stdout: %s",
|
||||||
|
table,
|
||||||
|
stdout,
|
||||||
|
)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
||||||
|
if check() != NFTABLES {
|
||||||
|
t.Skip("nftables not supported on this system")
|
||||||
|
}
|
||||||
|
|
||||||
|
if _, err := exec.LookPath("iptables-save"); err != nil {
|
||||||
|
t.Skipf("iptables-save not available on this system: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// First ensure iptables-nft tables exist by running iptables-save
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock)
|
||||||
|
require.NoError(t, err, "failed to create manager")
|
||||||
|
require.NoError(t, manager.Init(nil))
|
||||||
|
|
||||||
|
t.Cleanup(func() {
|
||||||
|
err := manager.Close(nil)
|
||||||
|
require.NoError(t, err, "failed to reset manager state")
|
||||||
|
|
||||||
|
// Verify iptables output after reset
|
||||||
|
stdout, stderr := runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
})
|
||||||
|
|
||||||
|
ip := net.ParseIP("100.96.0.1")
|
||||||
|
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
&fw.Port{Values: []uint16{443}},
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
|
pair := fw.RouterPair{
|
||||||
|
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
Masquerade: true,
|
||||||
|
}
|
||||||
|
err = manager.AddNatRule(pair)
|
||||||
|
require.NoError(t, err, "failed to add NAT rule")
|
||||||
|
|
||||||
|
stdout, stderr = runIptablesSave(t)
|
||||||
|
verifyIptablesOutput(t, stdout, stderr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func compareExprsIgnoringCounters(t *testing.T, got, want []expr.Any) {
|
||||||
|
t.Helper()
|
||||||
|
require.Equal(t, len(got), len(want), "expression count mismatch")
|
||||||
|
|
||||||
|
for i := range got {
|
||||||
|
if _, isCounter := got[i].(*expr.Counter); isCounter {
|
||||||
|
_, wantIsCounter := want[i].(*expr.Counter)
|
||||||
|
require.True(t, wantIsCounter, "expected Counter at index %d", i)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
require.Equal(t, got[i], want[i], "expression mismatch at index %d", i)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -14,23 +14,31 @@ import (
|
|||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
|
"github.com/google/nftables/xt"
|
||||||
"github.com/hashicorp/go-multierror"
|
"github.com/hashicorp/go-multierror"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
chainNameRoutingFw = "netbird-rt-fwd"
|
tableNat = "nat"
|
||||||
chainNameRoutingNat = "netbird-rt-postrouting"
|
chainNameNatPrerouting = "PREROUTING"
|
||||||
chainNameForward = "FORWARD"
|
chainNameRoutingFw = "netbird-rt-fwd"
|
||||||
|
chainNameRoutingNat = "netbird-rt-postrouting"
|
||||||
|
chainNameRoutingRdr = "netbird-rt-redirect"
|
||||||
|
chainNameForward = "FORWARD"
|
||||||
|
|
||||||
userDataAcceptForwardRuleIif = "frwacceptiif"
|
userDataAcceptForwardRuleIif = "frwacceptiif"
|
||||||
userDataAcceptForwardRuleOif = "frwacceptoif"
|
userDataAcceptForwardRuleOif = "frwacceptoif"
|
||||||
|
|
||||||
|
dnatSuffix = "_dnat"
|
||||||
|
snatSuffix = "_snat"
|
||||||
)
|
)
|
||||||
|
|
||||||
const refreshRulesMapError = "refresh rules map: %w"
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
@@ -49,16 +57,18 @@ type router struct {
|
|||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
legacyManagement bool
|
legacyManagement bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
func newRouter(workTable *nftables.Table, wgIface iFaceMapper) (*router, error) {
|
||||||
r := &router{
|
r := &router{
|
||||||
conn: &nftables.Conn{},
|
conn: &nftables.Conn{},
|
||||||
workTable: workTable,
|
workTable: workTable,
|
||||||
chains: make(map[string]*nftables.Chain),
|
chains: make(map[string]*nftables.Chain),
|
||||||
rules: make(map[string]*nftables.Rule),
|
rules: make(map[string]*nftables.Rule),
|
||||||
wgIface: wgIface,
|
wgIface: wgIface,
|
||||||
|
ipFwdState: ipfwdstate.NewIPForwardingState(),
|
||||||
}
|
}
|
||||||
|
|
||||||
r.ipsetCounter = refcounter.New(
|
r.ipsetCounter = refcounter.New(
|
||||||
@@ -98,7 +108,52 @@ func (r *router) Reset() error {
|
|||||||
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
// clear without deleting the ipsets, the nf table will be deleted by the caller
|
||||||
r.ipsetCounter.Clear()
|
r.ipsetCounter.Clear()
|
||||||
|
|
||||||
return r.removeAcceptForwardRules()
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
if err := r.removeAcceptForwardRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove accept forward rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.removeNatPreroutingRules(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove filter prerouting rules: %w", err))
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) removeNatPreroutingRules() error {
|
||||||
|
table := &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
}
|
||||||
|
chain := &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: table,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
}
|
||||||
|
rules, err := r.conn.GetRules(table, chain)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get rules from nat table: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
|
||||||
|
// Delete rules that have our UserData suffix
|
||||||
|
for _, rule := range rules {
|
||||||
|
if len(rule.UserData) == 0 || !strings.HasSuffix(string(rule.UserData), dnatSuffix) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete rule %s: %w", rule.UserData, err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
@@ -133,14 +188,22 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameRoutingRdr] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameRoutingRdr,
|
||||||
|
Table: r.workTable,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
})
|
||||||
|
|
||||||
// Chain is created by acl manager
|
// Chain is created by acl manager
|
||||||
// TODO: move creation to a common place
|
// TODO: move creation to a common place
|
||||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
r.chains[chainNamePrerouting] = &nftables.Chain{
|
||||||
Name: chainNamePrerouting,
|
Name: chainNamePrerouting,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
// Add the single NAT rule that matches on mark
|
||||||
@@ -165,6 +228,7 @@ func (r *router) createContainers() error {
|
|||||||
|
|
||||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -173,7 +237,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
@@ -233,7 +297,13 @@ func (r *router) AddRouteFiltering(
|
|||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
}
|
}
|
||||||
|
|
||||||
rule = r.conn.AddRule(rule)
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
|
if action == firewall.ActionDrop {
|
||||||
|
// TODO: Insert after the established rule
|
||||||
|
rule = r.conn.InsertRule(rule)
|
||||||
|
} else {
|
||||||
|
rule = r.conn.AddRule(rule)
|
||||||
|
}
|
||||||
|
|
||||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
@@ -275,7 +345,7 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleKey := rule.GetRuleID()
|
ruleKey := rule.ID()
|
||||||
nftRule, exists := r.rules[ruleKey]
|
nftRule, exists := r.rules[ruleKey]
|
||||||
if !exists {
|
if !exists {
|
||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
@@ -404,6 +474,10 @@ func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
|||||||
|
|
||||||
// AddNatRule appends a nftables rule pair to the nat chain
|
// AddNatRule appends a nftables rule pair to the nat chain
|
||||||
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -830,6 +904,10 @@ func (r *router) removeAcceptForwardRulesIptables(ipt *iptables.IPTables) error
|
|||||||
|
|
||||||
// RemoveNatRule removes the prerouting mark rule
|
// RemoveNatRule removes the prerouting mark rule
|
||||||
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.refreshRulesMap(); err != nil {
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
@@ -890,6 +968,269 @@ func (r *router) refreshRulesMap() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *router) AddDNATRule(rule firewall.ForwardRule) (firewall.Rule, error) {
|
||||||
|
if err := r.ipFwdState.RequestForwarding(); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
if _, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
protoNum, err := protoToInt(rule.Protocol)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("convert protocol to number: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.addDnatRedirect(rule, protoNum, ruleKey); err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
|
r.addDnatMasq(rule, protoNum, ruleKey)
|
||||||
|
|
||||||
|
// Unlike iptables, there's no point in adding "out" rules in the forward chain here as our policy is ACCEPT.
|
||||||
|
// To overcome DROP policies in other chains, we'd have to add rules to the chains there.
|
||||||
|
// We also cannot just add "oif <iface> accept" there and filter in our own table as we don't know what is supposed to be allowed.
|
||||||
|
// TODO: find chains with drop policies and add rules there
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush rules: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return &rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addDnatRedirect(rule firewall.ForwardRule, protoNum uint8, ruleKey string) error {
|
||||||
|
dnatExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyIIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseTransportHeader,
|
||||||
|
Offset: 2,
|
||||||
|
Len: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, applyPort(&rule.DestinationPort, false)...)
|
||||||
|
|
||||||
|
// shifted translated port is not supported in nftables, so we hand this over to xtables
|
||||||
|
if rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2 {
|
||||||
|
if rule.TranslatedPort.Values[0] != rule.DestinationPort.Values[0] ||
|
||||||
|
rule.TranslatedPort.Values[1] != rule.DestinationPort.Values[1] {
|
||||||
|
return r.addXTablesRedirect(dnatExprs, ruleKey, rule)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
additionalExprs, regProtoMin, regProtoMax, err := r.handleTranslatedPort(rule)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
dnatExprs = append(dnatExprs, additionalExprs...)
|
||||||
|
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.NAT{
|
||||||
|
Type: expr.NATTypeDestNAT,
|
||||||
|
Family: uint32(nftables.TableFamilyIPv4),
|
||||||
|
RegAddrMin: 1,
|
||||||
|
RegProtoMin: regProtoMin,
|
||||||
|
RegProtoMax: regProtoMax,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingRdr],
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleKey + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleTranslatedPort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
switch {
|
||||||
|
case rule.TranslatedPort.IsRange && len(rule.TranslatedPort.Values) == 2:
|
||||||
|
return r.handlePortRange(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 0:
|
||||||
|
return r.handleAddressOnly(rule)
|
||||||
|
case len(rule.TranslatedPort.Values) == 1:
|
||||||
|
return r.handleSinglePort(rule)
|
||||||
|
default:
|
||||||
|
return nil, 0, 0, fmt.Errorf("invalid translated port: %v", rule.TranslatedPort)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handlePortRange(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 3,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[1]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 3, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleAddressOnly(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 0, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) handleSinglePort(rule firewall.ForwardRule) ([]expr.Any, uint32, uint32, error) {
|
||||||
|
exprs := []expr.Any{
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 2,
|
||||||
|
Data: binaryutil.BigEndian.PutUint16(rule.TranslatedPort.Values[0]),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
return exprs, 2, 0, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addXTablesRedirect(dnatExprs []expr.Any, ruleKey string, rule firewall.ForwardRule) error {
|
||||||
|
dnatExprs = append(dnatExprs,
|
||||||
|
&expr.Counter{},
|
||||||
|
&expr.Target{
|
||||||
|
Name: "DNAT",
|
||||||
|
Rev: 2,
|
||||||
|
Info: &xt.NatRange2{
|
||||||
|
NatRange: xt.NatRange{
|
||||||
|
Flags: uint(xt.NatRangeMapIPs | xt.NatRangeProtoSpecified | xt.NatRangeProtoOffset),
|
||||||
|
MinIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MaxIP: rule.TranslatedAddress.AsSlice(),
|
||||||
|
MinPort: rule.TranslatedPort.Values[0],
|
||||||
|
MaxPort: rule.TranslatedPort.Values[1],
|
||||||
|
},
|
||||||
|
BasePort: rule.DestinationPort.Values[0],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
dnatRule := &nftables.Rule{
|
||||||
|
Table: &nftables.Table{
|
||||||
|
Name: tableNat,
|
||||||
|
Family: nftables.TableFamilyIPv4,
|
||||||
|
},
|
||||||
|
Chain: &nftables.Chain{
|
||||||
|
Name: chainNameNatPrerouting,
|
||||||
|
Table: r.filterTable,
|
||||||
|
Type: nftables.ChainTypeNAT,
|
||||||
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
|
Priority: nftables.ChainPriorityNATDest,
|
||||||
|
},
|
||||||
|
Exprs: dnatExprs,
|
||||||
|
UserData: []byte(ruleKey + dnatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(dnatRule)
|
||||||
|
r.rules[ruleKey+dnatSuffix] = dnatRule
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) addDnatMasq(rule firewall.ForwardRule, protoNum uint8, ruleKey string) {
|
||||||
|
masqExprs := []expr.Any{
|
||||||
|
&expr.Meta{Key: expr.MetaKeyOIFNAME, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
&expr.Meta{Key: expr.MetaKeyL4PROTO, Register: 1},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{protoNum},
|
||||||
|
},
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: 16,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: rule.TranslatedAddress.AsSlice(),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
masqExprs = append(masqExprs, applyPort(&rule.TranslatedPort, false)...)
|
||||||
|
masqExprs = append(masqExprs, &expr.Masq{})
|
||||||
|
|
||||||
|
masqRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameRoutingNat],
|
||||||
|
Exprs: masqExprs,
|
||||||
|
UserData: []byte(ruleKey + snatSuffix),
|
||||||
|
}
|
||||||
|
r.conn.AddRule(masqRule)
|
||||||
|
r.rules[ruleKey+snatSuffix] = masqRule
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
||||||
|
if err := r.ipFwdState.ReleaseForwarding(); err != nil {
|
||||||
|
log.Errorf("%v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
|
if err := r.refreshRulesMap(); err != nil {
|
||||||
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
var merr *multierror.Error
|
||||||
|
if dnatRule, exists := r.rules[ruleKey+dnatSuffix]; exists {
|
||||||
|
if err := r.conn.DelRule(dnatRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete dnat rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if masqRule, exists := r.rules[ruleKey+snatSuffix]; exists {
|
||||||
|
if err := r.conn.DelRule(masqRule); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("delete snat rule: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf(flushError, err))
|
||||||
|
}
|
||||||
|
|
||||||
|
if merr == nil {
|
||||||
|
delete(r.rules, ruleKey+dnatSuffix)
|
||||||
|
delete(r.rules, ruleKey+snatSuffix)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
||||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
||||||
var offset uint32
|
var offset uint32
|
||||||
@@ -953,15 +1294,11 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
if port.IsRange && len(port.Values) == 2 {
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
// Handle port range
|
// Handle port range
|
||||||
exprs = append(exprs,
|
exprs = append(exprs,
|
||||||
&expr.Cmp{
|
&expr.Range{
|
||||||
Op: expr.CmpOpGte,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[0])),
|
FromData: binaryutil.BigEndian.PutUint16(port.Values[0]),
|
||||||
},
|
ToData: binaryutil.BigEndian.PutUint16(port.Values[1]),
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpLte,
|
|
||||||
Register: 1,
|
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(port.Values[1])),
|
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
} else {
|
} else {
|
||||||
@@ -980,7 +1317,7 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
exprs = append(exprs, &expr.Cmp{
|
exprs = append(exprs, &expr.Cmp{
|
||||||
Op: expr.CmpOpEq,
|
Op: expr.CmpOpEq,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
Data: binaryutil.BigEndian.PutUint16(uint16(p)),
|
Data: binaryutil.BigEndian.PutUint16(p),
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -38,7 +38,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
// need fw manager to init both acl mgr and router for all chains to be present
|
// need fw manager to init both acl mgr and router for all chains to be present
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
@@ -127,7 +127,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
t.Run(testCase.Name, func(t *testing.T) {
|
t.Run(testCase.Name, func(t *testing.T) {
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock)
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
require.NoError(t, manager.Reset(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
})
|
})
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NoError(t, manager.Init(nil))
|
require.NoError(t, manager.Init(nil))
|
||||||
@@ -222,7 +222,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{80}},
|
dPort: &firewall.Port{Values: []uint16{80}},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -235,7 +235,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
destination: netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 2048}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 2048}, IsRange: true},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
@@ -268,7 +268,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
sources: []netip.Prefix{netip.MustParsePrefix("172.16.0.0/12")},
|
||||||
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
destination: netip.MustParsePrefix("192.168.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{80, 443, 8080}},
|
sPort: &firewall.Port{Values: []uint16{80, 443, 8080}},
|
||||||
dPort: nil,
|
dPort: nil,
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
@@ -280,7 +280,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
destination: netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
proto: firewall.ProtocolUDP,
|
proto: firewall.ProtocolUDP,
|
||||||
sPort: nil,
|
sPort: nil,
|
||||||
dPort: &firewall.Port{Values: []int{5000, 5100}, IsRange: true},
|
dPort: &firewall.Port{Values: []uint16{5000, 5100}, IsRange: true},
|
||||||
direction: firewall.RuleDirectionIN,
|
direction: firewall.RuleDirectionIN,
|
||||||
action: firewall.ActionDrop,
|
action: firewall.ActionDrop,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -290,8 +290,8 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
sources: []netip.Prefix{netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
destination: netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
proto: firewall.ProtocolTCP,
|
proto: firewall.ProtocolTCP,
|
||||||
sPort: &firewall.Port{Values: []int{1024, 65535}, IsRange: true},
|
sPort: &firewall.Port{Values: []uint16{1024, 65535}, IsRange: true},
|
||||||
dPort: &firewall.Port{Values: []int{22}},
|
dPort: &firewall.Port{Values: []uint16{22}},
|
||||||
direction: firewall.RuleDirectionOUT,
|
direction: firewall.RuleDirectionOUT,
|
||||||
action: firewall.ActionAccept,
|
action: firewall.ActionAccept,
|
||||||
expectSet: false,
|
expectSet: false,
|
||||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -319,7 +319,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
rule, ok := r.rules[ruleKey.GetRuleID()]
|
rule, ok := r.rules[ruleKey.ID()]
|
||||||
assert.True(t, ok, "Rule not found in internal map")
|
assert.True(t, ok, "Rule not found in internal map")
|
||||||
|
|
||||||
t.Log("Internal rule expressions:")
|
t.Log("Internal rule expressions:")
|
||||||
@@ -336,7 +336,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
var nftRule *nftables.Rule
|
var nftRule *nftables.Rule
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if string(rule.UserData) == ruleKey.GetRuleID() {
|
if string(rule.UserData) == ruleKey.ID() {
|
||||||
nftRule = rule
|
nftRule = rule
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
@@ -595,16 +595,20 @@ func containsPort(exprs []expr.Any, port *firewall.Port, isSource bool) bool {
|
|||||||
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
if ex.Base == expr.PayloadBaseTransportHeader && ex.Offset == offset && ex.Len == 2 {
|
||||||
payloadFound = true
|
payloadFound = true
|
||||||
}
|
}
|
||||||
case *expr.Cmp:
|
case *expr.Range:
|
||||||
if port.IsRange {
|
if port.IsRange && len(port.Values) == 2 {
|
||||||
if ex.Op == expr.CmpOpGte || ex.Op == expr.CmpOpLte {
|
fromPort := binary.BigEndian.Uint16(ex.FromData)
|
||||||
|
toPort := binary.BigEndian.Uint16(ex.ToData)
|
||||||
|
if fromPort == port.Values[0] && toPort == port.Values[1] {
|
||||||
portMatchFound = true
|
portMatchFound = true
|
||||||
}
|
}
|
||||||
} else {
|
}
|
||||||
|
case *expr.Cmp:
|
||||||
|
if !port.IsRange {
|
||||||
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
if ex.Op == expr.CmpOpEq && len(ex.Data) == 2 {
|
||||||
portValue := binary.BigEndian.Uint16(ex.Data)
|
portValue := binary.BigEndian.Uint16(ex.Data)
|
||||||
for _, p := range port.Values {
|
for _, p := range port.Values {
|
||||||
if uint16(p) == portValue {
|
if p == portValue {
|
||||||
portMatchFound = true
|
portMatchFound = true
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -8,13 +8,14 @@ import (
|
|||||||
|
|
||||||
// Rule to handle management of rules
|
// Rule to handle management of rules
|
||||||
type Rule struct {
|
type Rule struct {
|
||||||
nftRule *nftables.Rule
|
nftRule *nftables.Rule
|
||||||
nftSet *nftables.Set
|
mangleRule *nftables.Rule
|
||||||
ruleID string
|
nftSet *nftables.Set
|
||||||
ip net.IP
|
ruleID string
|
||||||
|
ip net.IP
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// GetRuleID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *Rule) ID() string {
|
||||||
return r.ruleID
|
return r.ruleID
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1 +0,0 @@
|
|||||||
package nftables
|
|
||||||
@@ -3,21 +3,20 @@ package nftables
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
type InterfaceState struct {
|
type InterfaceState struct {
|
||||||
NameStr string `json:"name"`
|
NameStr string `json:"name"`
|
||||||
WGAddress iface.WGAddress `json:"wg_address"`
|
WGAddress wgaddr.Address `json:"wg_address"`
|
||||||
UserspaceBind bool `json:"userspace_bind"`
|
UserspaceBind bool `json:"userspace_bind"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Name() string {
|
func (i *InterfaceState) Name() string {
|
||||||
return i.NameStr
|
return i.NameStr
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *InterfaceState) Address() device.WGAddress {
|
func (i *InterfaceState) Address() wgaddr.Address {
|
||||||
return i.WGAddress
|
return i.WGAddress
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -39,7 +38,7 @@ func (s *ShutdownState) Cleanup() error {
|
|||||||
return fmt.Errorf("create nftables manager: %w", err)
|
return fmt.Errorf("create nftables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := nft.Reset(nil); err != nil {
|
if err := nft.Close(nil); err != nil {
|
||||||
return fmt.Errorf("reset nftables manager: %w", err)
|
return fmt.Errorf("reset nftables manager: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -2,18 +2,50 @@
|
|||||||
|
|
||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import "github.com/netbirdio/netbird/client/internal/statemanager"
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if m.nativeFirewall != nil {
|
if m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.Reset(stateManager)
|
return m.nativeFirewall.Close(stateManager)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,12 +1,16 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -19,12 +23,39 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Reset(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
|
if m.udpTracker != nil {
|
||||||
|
m.udpTracker.Close()
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.icmpTracker != nil {
|
||||||
|
m.icmpTracker.Close()
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.tcpTracker != nil {
|
||||||
|
m.tcpTracker.Close()
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
|
fwder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.logger != nil {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 2*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := m.logger.Stop(ctx); err != nil {
|
||||||
|
log.Errorf("failed to shutdown logger: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !isWindowsFirewallReachable() {
|
if !isWindowsFirewallReachable() {
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
16
client/firewall/uspfilter/common/iface.go
Normal file
16
client/firewall/uspfilter/common/iface.go
Normal file
@@ -0,0 +1,16 @@
|
|||||||
|
package common
|
||||||
|
|
||||||
|
import (
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
// IFaceMapper defines subset methods of interface required for manager
|
||||||
|
type IFaceMapper interface {
|
||||||
|
SetFilter(device.PacketFilter) error
|
||||||
|
Address() wgaddr.Address
|
||||||
|
GetWGDevice() *wgdevice.Device
|
||||||
|
GetDevice() *device.FilteredDevice
|
||||||
|
}
|
||||||
66
client/firewall/uspfilter/conntrack/common.go
Normal file
66
client/firewall/uspfilter/conntrack/common.go
Normal file
@@ -0,0 +1,66 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
|
type BaseConnTrack struct {
|
||||||
|
FlowId uuid.UUID
|
||||||
|
Direction nftypes.Direction
|
||||||
|
SourceIP netip.Addr
|
||||||
|
DestIP netip.Addr
|
||||||
|
lastSeen atomic.Int64
|
||||||
|
PacketsTx atomic.Uint64
|
||||||
|
PacketsRx atomic.Uint64
|
||||||
|
BytesTx atomic.Uint64
|
||||||
|
BytesRx atomic.Uint64
|
||||||
|
}
|
||||||
|
|
||||||
|
// these small methods will be inlined by the compiler
|
||||||
|
|
||||||
|
// UpdateLastSeen safely updates the last seen timestamp
|
||||||
|
func (b *BaseConnTrack) UpdateLastSeen() {
|
||||||
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
// UpdateCounters safely updates the packet and byte counters
|
||||||
|
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||||
|
if direction == nftypes.Egress {
|
||||||
|
b.PacketsTx.Add(1)
|
||||||
|
b.BytesTx.Add(uint64(bytes))
|
||||||
|
} else {
|
||||||
|
b.PacketsRx.Add(1)
|
||||||
|
b.BytesRx.Add(uint64(bytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
|
}
|
||||||
|
|
||||||
|
// timeoutExceeded checks if the connection has exceeded the given timeout
|
||||||
|
func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
||||||
|
lastSeen := time.Unix(0, b.lastSeen.Load())
|
||||||
|
return time.Since(lastSeen) > timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
// ConnKey uniquely identifies a connection
|
||||||
|
type ConnKey struct {
|
||||||
|
SrcIP netip.Addr
|
||||||
|
DstIP netip.Addr
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c ConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
|
}
|
||||||
68
client/firewall/uspfilter/conntrack/common_test.go
Normal file
68
client/firewall/uspfilter/conntrack/common_test.go
Normal file
@@ -0,0 +1,68 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
|
// Memory pressure tests
|
||||||
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Generate different IPs
|
||||||
|
srcIPs := make([]netip.Addr, 100)
|
||||||
|
dstIPs := make([]netip.Addr, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
srcIdx := i % len(srcIPs)
|
||||||
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Simulate some valid inbound packets
|
||||||
|
if i%3 == 0 {
|
||||||
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Generate different IPs
|
||||||
|
srcIPs := make([]netip.Addr, 100)
|
||||||
|
dstIPs := make([]netip.Addr, 100)
|
||||||
|
for i := 0; i < 100; i++ {
|
||||||
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
srcIdx := i % len(srcIPs)
|
||||||
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||||
|
|
||||||
|
// Simulate some valid inbound packets
|
||||||
|
if i%3 == 0 {
|
||||||
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
245
client/firewall/uspfilter/conntrack/icmp.go
Normal file
245
client/firewall/uspfilter/conntrack/icmp.go
Normal file
@@ -0,0 +1,245 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultICMPTimeout is the default timeout for ICMP connections
|
||||||
|
DefaultICMPTimeout = 30 * time.Second
|
||||||
|
// ICMPCleanupInterval is how often we check for stale ICMP connections
|
||||||
|
ICMPCleanupInterval = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
|
type ICMPConnKey struct {
|
||||||
|
SrcIP netip.Addr
|
||||||
|
DstIP netip.Addr
|
||||||
|
ID uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i ICMPConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
|
type ICMPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
ICMPType uint8
|
||||||
|
ICMPCode uint8
|
||||||
|
}
|
||||||
|
|
||||||
|
// ICMPTracker manages ICMP connection states
|
||||||
|
type ICMPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
connections map[ICMPConnKey]*ICMPConnTrack
|
||||||
|
timeout time.Duration
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
tickerCancel context.CancelFunc
|
||||||
|
mutex sync.RWMutex
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultICMPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
tracker := &ICMPTracker{
|
||||||
|
logger: logger,
|
||||||
|
connections: make(map[ICMPConnKey]*ICMPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
|
tickerCancel: cancel,
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine(ctx)
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||||
|
key := ICMPConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound ICMP connection
|
||||||
|
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
|
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
|
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
|
||||||
|
// non echo requests don't need tracking
|
||||||
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &ICMPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||||
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
key := ICMPConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.tickerCancel()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Debug("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *ICMPTracker) Close() {
|
||||||
|
t.tickerCancel()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
ICMPType: conn.ICMPType,
|
||||||
|
ICMPCode: conn.ICMPCode,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeStart,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: direction,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
fields.RxPackets = 1
|
||||||
|
fields.RxBytes = uint64(size)
|
||||||
|
} else {
|
||||||
|
fields.TxPackets = 1
|
||||||
|
fields.TxBytes = uint64(size)
|
||||||
|
}
|
||||||
|
t.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
39
client/firewall/uspfilter/conntrack/icmp_test.go
Normal file
39
client/firewall/uspfilter/conntrack/icmp_test.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
490
client/firewall/uspfilter/conntrack/tcp.go
Normal file
490
client/firewall/uspfilter/conntrack/tcp.go
Normal file
@@ -0,0 +1,490 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
// TODO: Send RST packets for invalid/timed-out connections
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// MSL (Maximum Segment Lifetime) is typically 2 minutes
|
||||||
|
MSL = 2 * time.Minute
|
||||||
|
// TimeWaitTimeout (TIME-WAIT) should last 2*MSL
|
||||||
|
TimeWaitTimeout = 2 * MSL
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
TCPSyn uint8 = 0x02
|
||||||
|
TCPAck uint8 = 0x10
|
||||||
|
TCPFin uint8 = 0x01
|
||||||
|
TCPRst uint8 = 0x04
|
||||||
|
TCPPush uint8 = 0x08
|
||||||
|
TCPUrg uint8 = 0x20
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultTCPTimeout is the default timeout for established TCP connections
|
||||||
|
DefaultTCPTimeout = 3 * time.Hour
|
||||||
|
// TCPHandshakeTimeout is timeout for TCP handshake completion
|
||||||
|
TCPHandshakeTimeout = 60 * time.Second
|
||||||
|
// TCPCleanupInterval is how often we check for stale connections
|
||||||
|
TCPCleanupInterval = 5 * time.Minute
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPState represents the state of a TCP connection
|
||||||
|
type TCPState int
|
||||||
|
|
||||||
|
func (s TCPState) String() string {
|
||||||
|
switch s {
|
||||||
|
case TCPStateNew:
|
||||||
|
return "New"
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return "SYN Sent"
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return "SYN Received"
|
||||||
|
case TCPStateEstablished:
|
||||||
|
return "Established"
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return "FIN Wait 1"
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return "FIN Wait 2"
|
||||||
|
case TCPStateClosing:
|
||||||
|
return "Closing"
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
return "Time Wait"
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return "Close Wait"
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return "Last ACK"
|
||||||
|
case TCPStateClosed:
|
||||||
|
return "Closed"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
const (
|
||||||
|
TCPStateNew TCPState = iota
|
||||||
|
TCPStateSynSent
|
||||||
|
TCPStateSynReceived
|
||||||
|
TCPStateEstablished
|
||||||
|
TCPStateFinWait1
|
||||||
|
TCPStateFinWait2
|
||||||
|
TCPStateClosing
|
||||||
|
TCPStateTimeWait
|
||||||
|
TCPStateCloseWait
|
||||||
|
TCPStateLastAck
|
||||||
|
TCPStateClosed
|
||||||
|
)
|
||||||
|
|
||||||
|
// TCPConnTrack represents a TCP connection state
|
||||||
|
type TCPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
|
State TCPState
|
||||||
|
established atomic.Bool
|
||||||
|
tombstone atomic.Bool
|
||||||
|
sync.RWMutex
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsEstablished safely checks if connection is established
|
||||||
|
func (t *TCPConnTrack) IsEstablished() bool {
|
||||||
|
return t.established.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetEstablished safely sets the established state
|
||||||
|
func (t *TCPConnTrack) SetEstablished(state bool) {
|
||||||
|
t.established.Store(state)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTombstone safely checks if the connection is marked for deletion
|
||||||
|
func (t *TCPConnTrack) IsTombstone() bool {
|
||||||
|
return t.tombstone.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTombstone safely marks the connection for deletion
|
||||||
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
|
t.tombstone.Store(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TCPTracker manages TCP connection states
|
||||||
|
type TCPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
connections map[ConnKey]*TCPConnTrack
|
||||||
|
mutex sync.RWMutex
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
tickerCancel context.CancelFunc
|
||||||
|
timeout time.Duration
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultTCPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
tracker := &TCPTracker{
|
||||||
|
logger: logger,
|
||||||
|
connections: make(map[ConnKey]*TCPConnTrack),
|
||||||
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
|
tickerCancel: cancel,
|
||||||
|
timeout: timeout,
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine(ctx)
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(key, conn, flags, conn.Direction == nftypes.Egress)
|
||||||
|
conn.Unlock()
|
||||||
|
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound TCP connection
|
||||||
|
func (t *TCPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, 0, 0); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
|
func (t *TCPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *TCPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &TCPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.established.Store(false)
|
||||||
|
conn.tombstone.Store(false)
|
||||||
|
|
||||||
|
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||||
|
t.updateState(key, conn, flags, direction == nftypes.Egress)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
|
func (t *TCPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, flags uint8, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Handle RST flag specially - it always causes transition to closed
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
if conn.IsTombstone() {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Lock()
|
||||||
|
conn.SetTombstone()
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
conn.Unlock()
|
||||||
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection reset: %s", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.Lock()
|
||||||
|
t.updateState(key, conn, flags, false)
|
||||||
|
isEstablished := conn.IsEstablished()
|
||||||
|
isValidState := t.isValidStateForFlags(conn.State, flags)
|
||||||
|
conn.Unlock()
|
||||||
|
|
||||||
|
return isEstablished || isValidState
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateState updates the TCP connection state based on flags
|
||||||
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
state := conn.State
|
||||||
|
defer func() {
|
||||||
|
if state != conn.State {
|
||||||
|
t.logger.Trace("TCP connection %s transitioned from %s to %s", key, state, conn.State)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case TCPStateNew:
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
|
conn.State = TCPStateSynSent
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateSynSent:
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
|
if isOutbound {
|
||||||
|
conn.State = TCPStateEstablished
|
||||||
|
conn.SetEstablished(true)
|
||||||
|
} else {
|
||||||
|
// Simultaneous open
|
||||||
|
conn.State = TCPStateSynReceived
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||||
|
conn.State = TCPStateEstablished
|
||||||
|
conn.SetEstablished(true)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateEstablished:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
if isOutbound {
|
||||||
|
conn.State = TCPStateFinWait1
|
||||||
|
} else {
|
||||||
|
conn.State = TCPStateCloseWait
|
||||||
|
}
|
||||||
|
conn.SetEstablished(false)
|
||||||
|
} else if flags&TCPRst != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
switch {
|
||||||
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
|
conn.State = TCPStateClosing
|
||||||
|
case flags&TCPFin != 0:
|
||||||
|
conn.State = TCPStateFinWait2
|
||||||
|
case flags&TCPAck != 0:
|
||||||
|
conn.State = TCPStateFinWait2
|
||||||
|
case flags&TCPRst != 0:
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
conn.State = TCPStateTimeWait
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection %s completed", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateClosing:
|
||||||
|
if flags&TCPAck != 0 {
|
||||||
|
conn.State = TCPStateTimeWait
|
||||||
|
// Keep established = false from previous state
|
||||||
|
|
||||||
|
t.logger.Trace("TCP connection %s closed (simultaneous)", key)
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
if flags&TCPFin != 0 {
|
||||||
|
conn.State = TCPStateLastAck
|
||||||
|
}
|
||||||
|
|
||||||
|
case TCPStateLastAck:
|
||||||
|
if flags&TCPAck != 0 {
|
||||||
|
conn.State = TCPStateClosed
|
||||||
|
conn.SetTombstone()
|
||||||
|
|
||||||
|
// Send close event for gracefully closed connections
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
t.logger.Trace("TCP connection %s closed gracefully", key)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// isValidStateForFlags checks if the TCP flags are valid for the current connection state
|
||||||
|
func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
||||||
|
if !isValidFlagCombination(flags) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
switch state {
|
||||||
|
case TCPStateNew:
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateEstablished:
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateClosing:
|
||||||
|
// In CLOSING state, we should accept the final ACK
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
// In TIME_WAIT, we might see retransmissions
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
case TCPStateClosed:
|
||||||
|
// Accept retransmitted ACKs in closed state
|
||||||
|
// This is important because the final ACK might be lost
|
||||||
|
// and the peer will retransmit their FIN-ACK
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.cleanupTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
if conn.IsTombstone() {
|
||||||
|
// Clean up tombstoned connections without sending an event
|
||||||
|
delete(t.connections, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
var timeout time.Duration
|
||||||
|
switch {
|
||||||
|
case conn.State == TCPStateTimeWait:
|
||||||
|
timeout = TimeWaitTimeout
|
||||||
|
case conn.IsEstablished():
|
||||||
|
timeout = t.timeout
|
||||||
|
default:
|
||||||
|
timeout = TCPHandshakeTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
if conn.timeoutExceeded(timeout) {
|
||||||
|
// Return IPs to pool
|
||||||
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Cleaned up timed-out TCP connection %s", key)
|
||||||
|
|
||||||
|
// event already handled by state change
|
||||||
|
if conn.State != TCPStateTimeWait {
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *TCPTracker) Close() {
|
||||||
|
t.tickerCancel()
|
||||||
|
|
||||||
|
// Clean up all remaining IPs
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
func isValidFlagCombination(flags uint8) bool {
|
||||||
|
// Invalid: SYN+FIN
|
||||||
|
if flags&TCPSyn != 0 && flags&TCPFin != 0 {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
// Invalid: RST with SYN or FIN
|
||||||
|
if flags&TCPRst != 0 && (flags&TCPSyn != 0 || flags&TCPFin != 0) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
313
client/firewall/uspfilter/conntrack/tcp_test.go
Normal file
313
client/firewall/uspfilter/conntrack/tcp_test.go
Normal file
@@ -0,0 +1,313 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Security Tests", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
flags uint8
|
||||||
|
wantDrop bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Block unsolicited SYN-ACK",
|
||||||
|
flags: TCPSyn | TCPAck,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block SYN-ACK without prior SYN",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block invalid SYN-FIN",
|
||||||
|
flags: TCPSyn | TCPFin,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block invalid SYN-FIN combination",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block unsolicited RST",
|
||||||
|
flags: TCPRst,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block RST without connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block unsolicited ACK",
|
||||||
|
flags: TCPAck,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block ACK without connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block data without connection",
|
||||||
|
flags: TCPAck | TCPPush,
|
||||||
|
wantDrop: true,
|
||||||
|
desc: "Should block data without established connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||||
|
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("Connection Flow Tests", func(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
test func(*testing.T)
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Normal Handshake",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// Send initial SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Receive SYN-ACK
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
|
// Send ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// Test data transfer
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||||
|
require.True(t, valid, "Data should be allowed after handshake")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Normal Close",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "ACK for FIN should be allowed")
|
||||||
|
|
||||||
|
// Receive FIN from other side
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid, "FIN should be allowed")
|
||||||
|
|
||||||
|
// Send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RST During Connection",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Receive RST
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
|
// Connection is logically dead but we don't enforce blocking subsequent packets
|
||||||
|
// The connection will be cleaned up by timeout
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Simultaneous Close",
|
||||||
|
test: func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
// First establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Both sides send FIN+ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||||
|
|
||||||
|
// Both sides send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid, "Final ACKs should be allowed")
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
tt.test(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestRSTHandling(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupState func()
|
||||||
|
sendRST func()
|
||||||
|
wantValid bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "RST in established",
|
||||||
|
setupState: func() {
|
||||||
|
// Establish connection first
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
},
|
||||||
|
sendRST: func() {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
|
},
|
||||||
|
wantValid: true,
|
||||||
|
desc: "Should accept RST for established connection",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RST without connection",
|
||||||
|
setupState: func() {},
|
||||||
|
sendRST: func() {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
|
},
|
||||||
|
wantValid: false,
|
||||||
|
desc: "Should reject RST without connection",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.setupState()
|
||||||
|
tt.sendRST()
|
||||||
|
|
||||||
|
// Verify connection state is as expected
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
if tt.wantValid {
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateClosed, conn.State)
|
||||||
|
require.False(t, conn.IsEstablished())
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Helper to establish a TCP connection
|
||||||
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
if i%2 == 0 {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
} else {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck, 0)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark connection cleanup
|
||||||
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger) // Short timeout for testing
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Pre-populate with expired connections
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for connections to expire
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.cleanup()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
219
client/firewall/uspfilter/conntrack/udp.go
Normal file
219
client/firewall/uspfilter/conntrack/udp.go
Normal file
@@ -0,0 +1,219 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
// DefaultUDPTimeout is the default timeout for UDP connections
|
||||||
|
DefaultUDPTimeout = 30 * time.Second
|
||||||
|
// UDPCleanupInterval is how often we check for stale connections
|
||||||
|
UDPCleanupInterval = 15 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// UDPConnTrack represents a UDP connection state
|
||||||
|
type UDPConnTrack struct {
|
||||||
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
|
}
|
||||||
|
|
||||||
|
// UDPTracker manages UDP connection states
|
||||||
|
type UDPTracker struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
connections map[ConnKey]*UDPConnTrack
|
||||||
|
timeout time.Duration
|
||||||
|
cleanupTicker *time.Ticker
|
||||||
|
tickerCancel context.CancelFunc
|
||||||
|
mutex sync.RWMutex
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultUDPTimeout
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
|
tracker := &UDPTracker{
|
||||||
|
logger: logger,
|
||||||
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
|
tickerCancel: cancel,
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
}
|
||||||
|
|
||||||
|
go tracker.cleanupRoutine(ctx)
|
||||||
|
return tracker
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound UDP connection
|
||||||
|
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound UDP connection
|
||||||
|
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
|
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
|
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanupRoutine periodically removes stale connections
|
||||||
|
func (t *UDPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
|
defer t.cleanupTicker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-t.cleanupTicker.C:
|
||||||
|
t.cleanup()
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) cleanup() {
|
||||||
|
t.mutex.Lock()
|
||||||
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
|
for key, conn := range t.connections {
|
||||||
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
|
delete(t.connections, key)
|
||||||
|
|
||||||
|
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Close stops the cleanup routine and releases resources
|
||||||
|
func (t *UDPTracker) Close() {
|
||||||
|
t.tickerCancel()
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections = nil
|
||||||
|
t.mutex.Unlock()
|
||||||
|
}
|
||||||
|
|
||||||
|
// GetConnection safely retrieves a connection state
|
||||||
|
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
|
t.mutex.RLock()
|
||||||
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
return conn, exists
|
||||||
|
}
|
||||||
|
|
||||||
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
|
return t.timeout
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
252
client/firewall/uspfilter/conntrack/udp_test.go
Normal file
252
client/firewall/uspfilter/conntrack/udp_test.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestNewUDPTracker(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
timeout time.Duration
|
||||||
|
wantTimeout time.Duration
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "with custom timeout",
|
||||||
|
timeout: 1 * time.Minute,
|
||||||
|
wantTimeout: 1 * time.Minute,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "with zero timeout uses default",
|
||||||
|
timeout: 0,
|
||||||
|
wantTimeout: DefaultUDPTimeout,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||||
|
assert.NotNil(t, tracker)
|
||||||
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
|
assert.NotNil(t, tracker.connections)
|
||||||
|
assert.NotNil(t, tracker.cleanupTicker)
|
||||||
|
assert.NotNil(t, tracker.tickerCancel)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
|
// Verify connection was tracked
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn, exists := tracker.connections[key]
|
||||||
|
require.True(t, exists)
|
||||||
|
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||||
|
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||||
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
|
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
// Track outbound connection
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
srcIP netip.Addr
|
||||||
|
dstIP netip.Addr
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
sleep time.Duration
|
||||||
|
want bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid inbound response",
|
||||||
|
srcIP: dstIP, // Original destination is now source
|
||||||
|
dstIP: srcIP, // Original source is now destination
|
||||||
|
srcPort: dstPort, // Original destination port is now source
|
||||||
|
dstPort: srcPort, // Original source port is now destination
|
||||||
|
sleep: 0,
|
||||||
|
want: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid source IP",
|
||||||
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid destination IP",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid source port",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: 54321,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "invalid destination port",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: 54321,
|
||||||
|
sleep: 0,
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "expired connection",
|
||||||
|
srcIP: dstIP,
|
||||||
|
dstIP: srcIP,
|
||||||
|
srcPort: dstPort,
|
||||||
|
dstPort: srcPort,
|
||||||
|
sleep: 2 * time.Second, // Longer than tracker timeout
|
||||||
|
want: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
if tt.sleep > 0 {
|
||||||
|
time.Sleep(tt.sleep)
|
||||||
|
}
|
||||||
|
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||||
|
assert.Equal(t, tt.want, got)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUDPTracker_Cleanup(t *testing.T) {
|
||||||
|
// Use shorter intervals for testing
|
||||||
|
timeout := 50 * time.Millisecond
|
||||||
|
cleanupInterval := 25 * time.Millisecond
|
||||||
|
|
||||||
|
ctx, tickerCancel := context.WithCancel(context.Background())
|
||||||
|
defer tickerCancel()
|
||||||
|
|
||||||
|
// Create tracker with custom cleanup interval
|
||||||
|
tracker := &UDPTracker{
|
||||||
|
connections: make(map[ConnKey]*UDPConnTrack),
|
||||||
|
timeout: timeout,
|
||||||
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
|
tickerCancel: tickerCancel,
|
||||||
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
}
|
||||||
|
|
||||||
|
// Start cleanup routine
|
||||||
|
go tracker.cleanupRoutine(ctx)
|
||||||
|
|
||||||
|
// Add some connections
|
||||||
|
connections := []struct {
|
||||||
|
srcIP netip.Addr
|
||||||
|
dstIP netip.Addr
|
||||||
|
srcPort uint16
|
||||||
|
dstPort uint16
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
|
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
|
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||||
|
srcPort: 12346,
|
||||||
|
dstPort: 53,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, conn := range connections {
|
||||||
|
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify initial connections
|
||||||
|
assert.Len(t, tracker.connections, 2)
|
||||||
|
|
||||||
|
// Wait for connection timeout and cleanup interval
|
||||||
|
time.Sleep(timeout + 2*cleanupInterval)
|
||||||
|
|
||||||
|
tracker.mutex.RLock()
|
||||||
|
connCount := len(tracker.connections)
|
||||||
|
tracker.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Verify connections were cleaned up
|
||||||
|
assert.Equal(t, 0, connCount, "Expected all connections to be cleaned up")
|
||||||
|
|
||||||
|
// Properly close the tracker
|
||||||
|
tracker.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
90
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
90
client/firewall/uspfilter/forwarder/endpoint.go
Normal file
@@ -0,0 +1,90 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
// endpoint implements stack.LinkEndpoint and handles integration with the wireguard device
|
||||||
|
type endpoint struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
dispatcher stack.NetworkDispatcher
|
||||||
|
device *wgdevice.Device
|
||||||
|
mtu uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Attach(dispatcher stack.NetworkDispatcher) {
|
||||||
|
e.dispatcher = dispatcher
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) IsAttached() bool {
|
||||||
|
return e.dispatcher != nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MTU() uint32 {
|
||||||
|
return e.mtu
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Capabilities() stack.LinkEndpointCapabilities {
|
||||||
|
return stack.CapabilityNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) MaxHeaderLength() uint16 {
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) LinkAddress() tcpip.LinkAddress {
|
||||||
|
return ""
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) WritePackets(pkts stack.PacketBufferList) (int, tcpip.Error) {
|
||||||
|
var written int
|
||||||
|
for _, pkt := range pkts.AsSlice() {
|
||||||
|
netHeader := header.IPv4(pkt.NetworkHeader().View().AsSlice())
|
||||||
|
|
||||||
|
data := stack.PayloadSince(pkt.NetworkHeader())
|
||||||
|
if data == nil {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
// Send the packet through WireGuard
|
||||||
|
address := netHeader.DestinationAddress()
|
||||||
|
err := e.device.CreateOutboundPacket(data.AsSlice(), address.AsSlice())
|
||||||
|
if err != nil {
|
||||||
|
e.logger.Error("CreateOutboundPacket: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
written++
|
||||||
|
}
|
||||||
|
|
||||||
|
return written, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) Wait() {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ARPHardwareType() header.ARPHardwareType {
|
||||||
|
return header.ARPHardwareNone
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
||||||
|
// not required
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
type epID stack.TransportEndpointID
|
||||||
|
|
||||||
|
func (i epID) String() string {
|
||||||
|
// src and remote is swapped
|
||||||
|
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
|
}
|
||||||
169
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
169
client/firewall/uspfilter/forwarder/forwarder.go
Normal file
@@ -0,0 +1,169 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"runtime"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"gvisor.dev/gvisor/pkg/buffer"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/network/ipv4"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/icmp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
defaultReceiveWindow = 32768
|
||||||
|
defaultMaxInFlight = 1024
|
||||||
|
iosReceiveWindow = 16384
|
||||||
|
iosMaxInFlight = 256
|
||||||
|
)
|
||||||
|
|
||||||
|
type Forwarder struct {
|
||||||
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
|
stack *stack.Stack
|
||||||
|
endpoint *endpoint
|
||||||
|
udpForwarder *udpForwarder
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ip net.IP
|
||||||
|
netstack bool
|
||||||
|
}
|
||||||
|
|
||||||
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||||
|
s := stack.New(stack.Options{
|
||||||
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
|
tcp.NewProtocol,
|
||||||
|
udp.NewProtocol,
|
||||||
|
icmp.NewProtocol4,
|
||||||
|
},
|
||||||
|
HandleLocal: false,
|
||||||
|
})
|
||||||
|
|
||||||
|
mtu, err := iface.GetDevice().MTU()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get MTU: %w", err)
|
||||||
|
}
|
||||||
|
nicID := tcpip.NICID(1)
|
||||||
|
endpoint := &endpoint{
|
||||||
|
logger: logger,
|
||||||
|
device: iface.GetWGDevice(),
|
||||||
|
mtu: uint32(mtu),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.CreateNIC(nicID, endpoint); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to create NIC: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ones, _ := iface.Address().Network.Mask.Size()
|
||||||
|
protoAddr := tcpip.ProtocolAddress{
|
||||||
|
Protocol: ipv4.ProtocolNumber,
|
||||||
|
AddressWithPrefix: tcpip.AddressWithPrefix{
|
||||||
|
Address: tcpip.AddrFromSlice(iface.Address().IP.To4()),
|
||||||
|
PrefixLen: ones,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.AddProtocolAddress(nicID, protoAddr, stack.AddressProperties{}); err != nil {
|
||||||
|
return nil, fmt.Errorf("failed to add protocol address: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defaultSubnet, err := tcpip.NewSubnet(
|
||||||
|
tcpip.AddrFrom4([4]byte{0, 0, 0, 0}),
|
||||||
|
tcpip.MaskFromBytes([]byte{0, 0, 0, 0}),
|
||||||
|
)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("creating default subnet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := s.SetPromiscuousMode(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set promiscuous mode: %s", err)
|
||||||
|
}
|
||||||
|
if err := s.SetSpoofing(nicID, true); err != nil {
|
||||||
|
return nil, fmt.Errorf("set spoofing: %s", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
s.SetRouteTable([]tcpip.Route{
|
||||||
|
{
|
||||||
|
Destination: defaultSubnet,
|
||||||
|
NIC: nicID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &Forwarder{
|
||||||
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
stack: s,
|
||||||
|
endpoint: endpoint,
|
||||||
|
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
netstack: netstack,
|
||||||
|
ip: iface.Address().IP,
|
||||||
|
}
|
||||||
|
|
||||||
|
receiveWindow := defaultReceiveWindow
|
||||||
|
maxInFlight := defaultMaxInFlight
|
||||||
|
if runtime.GOOS == "ios" {
|
||||||
|
receiveWindow = iosReceiveWindow
|
||||||
|
maxInFlight = iosMaxInFlight
|
||||||
|
}
|
||||||
|
|
||||||
|
tcpForwarder := tcp.NewForwarder(s, receiveWindow, maxInFlight, f.handleTCP)
|
||||||
|
s.SetTransportProtocolHandler(tcp.ProtocolNumber, tcpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
udpForwarder := udp.NewForwarder(s, f.handleUDP)
|
||||||
|
s.SetTransportProtocolHandler(udp.ProtocolNumber, udpForwarder.HandlePacket)
|
||||||
|
|
||||||
|
s.SetTransportProtocolHandler(icmp.ProtocolNumber4, f.handleICMP)
|
||||||
|
|
||||||
|
log.Debugf("forwarder: Initialization complete with NIC %d", nicID)
|
||||||
|
return f, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) InjectIncomingPacket(payload []byte) error {
|
||||||
|
if len(payload) < header.IPv4MinimumSize {
|
||||||
|
return fmt.Errorf("packet too small: %d bytes", len(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
pkt := stack.NewPacketBuffer(stack.PacketBufferOptions{
|
||||||
|
Payload: buffer.MakeWithData(payload),
|
||||||
|
})
|
||||||
|
defer pkt.DecRef()
|
||||||
|
|
||||||
|
if f.endpoint.dispatcher != nil {
|
||||||
|
f.endpoint.dispatcher.DeliverNetworkPacket(ipv4.ProtocolNumber, pkt)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the forwarder
|
||||||
|
func (f *Forwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
if f.udpForwarder != nil {
|
||||||
|
f.udpForwarder.Stop()
|
||||||
|
}
|
||||||
|
|
||||||
|
f.stack.Close()
|
||||||
|
f.stack.Wait()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) determineDialAddr(addr tcpip.Address) net.IP {
|
||||||
|
if f.netstack && f.ip.Equal(addr.AsSlice()) {
|
||||||
|
return net.IPv4(127, 0, 0, 1)
|
||||||
|
}
|
||||||
|
return addr.AsSlice()
|
||||||
|
}
|
||||||
127
client/firewall/uspfilter/forwarder/icmp.go
Normal file
127
client/firewall/uspfilter/forwarder/icmp.go
Normal file
@@ -0,0 +1,127 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleICMP handles ICMP packets from the network stack
|
||||||
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
icmpType := uint8(icmpHdr.Type())
|
||||||
|
icmpCode := uint8(icmpHdr.Code())
|
||||||
|
|
||||||
|
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
lc := net.ListenConfig{}
|
||||||
|
// TODO: support non-root
|
||||||
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||||
|
|
||||||
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
if err := conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("Failed to close ICMP socket: %v", err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||||
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
|
// For Echo Requests, send and handle response
|
||||||
|
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||||
|
f.handleEchoResponse(icmpHdr, conn, id)
|
||||||
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||||
|
}
|
||||||
|
|
||||||
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||||
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
response := make([]byte, f.endpoint.mtu)
|
||||||
|
n, _, err := conn.ReadFrom(response)
|
||||||
|
if err != nil {
|
||||||
|
if !isTimeout(err) {
|
||||||
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
|
ip := header.IPv4(ipHdr)
|
||||||
|
ip.Encode(&header.IPv4Fields{
|
||||||
|
TotalLength: uint16(header.IPv4MinimumSize + n),
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: uint8(header.ICMPv4ProtocolNumber),
|
||||||
|
SrcAddr: id.LocalAddress,
|
||||||
|
DstAddr: id.RemoteAddress,
|
||||||
|
})
|
||||||
|
ip.SetChecksum(^ip.CalculateChecksum())
|
||||||
|
|
||||||
|
fullPacket := make([]byte, 0, len(ipHdr)+n)
|
||||||
|
fullPacket = append(fullPacket, ipHdr...)
|
||||||
|
fullPacket = append(fullPacket, response[:n]...)
|
||||||
|
|
||||||
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||||
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
|
// TODO: get packets/bytes
|
||||||
|
})
|
||||||
|
}
|
||||||
132
client/firewall/uspfilter/forwarder/tcp.go
Normal file
132
client/firewall/uspfilter/forwarder/tcp.go
Normal file
@@ -0,0 +1,132 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
|
if err != nil {
|
||||||
|
r.Complete(true)
|
||||||
|
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Error("forwarder: failed to create TCP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
r.Complete(true)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Complete the handshake
|
||||||
|
r.Complete(false)
|
||||||
|
|
||||||
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||||
|
|
||||||
|
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
|
defer func() {
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
|
}
|
||||||
|
ep.Close()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Create context for managing the proxy goroutines
|
||||||
|
ctx, cancel := context.WithCancel(f.ctx)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(outConn, inConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
_, err := io.Copy(inConn, outConn)
|
||||||
|
errChan <- err
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
332
client/firewall/uspfilter/forwarder/udp.go
Normal file
332
client/firewall/uspfilter/forwarder/udp.go
Normal file
@@ -0,0 +1,332 @@
|
|||||||
|
package forwarder
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
"gvisor.dev/gvisor/pkg/tcpip/transport/udp"
|
||||||
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
udpTimeout = 30 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
type udpPacketConn struct {
|
||||||
|
conn *gonet.UDPConn
|
||||||
|
outConn net.Conn
|
||||||
|
lastSeen atomic.Int64
|
||||||
|
cancel context.CancelFunc
|
||||||
|
ep tcpip.Endpoint
|
||||||
|
flowID uuid.UUID
|
||||||
|
}
|
||||||
|
|
||||||
|
type udpForwarder struct {
|
||||||
|
sync.RWMutex
|
||||||
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
|
bufPool sync.Pool
|
||||||
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
|
}
|
||||||
|
|
||||||
|
type idleConn struct {
|
||||||
|
id stack.TransportEndpointID
|
||||||
|
conn *udpPacketConn
|
||||||
|
}
|
||||||
|
|
||||||
|
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||||
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
f := &udpForwarder{
|
||||||
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
b := make([]byte, mtu)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
go f.cleanup()
|
||||||
|
return f
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop stops the UDP forwarder and all active connections
|
||||||
|
func (f *udpForwarder) Stop() {
|
||||||
|
f.cancel()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
defer f.Unlock()
|
||||||
|
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
conn.cancel()
|
||||||
|
if err := conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||||
|
}
|
||||||
|
if err := conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.ep.Close()
|
||||||
|
delete(f.conns, id)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// cleanup periodically removes idle UDP connections
|
||||||
|
func (f *udpForwarder) cleanup() {
|
||||||
|
ticker := time.NewTicker(time.Minute)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-f.ctx.Done():
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
var idleConns []idleConn
|
||||||
|
|
||||||
|
f.RLock()
|
||||||
|
for id, conn := range f.conns {
|
||||||
|
if conn.getIdleDuration() > udpTimeout {
|
||||||
|
idleConns = append(idleConns, idleConn{id, conn})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
f.RUnlock()
|
||||||
|
|
||||||
|
for _, idle := range idleConns {
|
||||||
|
idle.conn.cancel()
|
||||||
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||||
|
}
|
||||||
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
idle.conn.ep.Close()
|
||||||
|
|
||||||
|
f.Lock()
|
||||||
|
delete(f.conns, idle.id)
|
||||||
|
f.Unlock()
|
||||||
|
|
||||||
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleUDP is called by the UDP forwarder for new packets
|
||||||
|
func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
||||||
|
if f.ctx.Err() != nil {
|
||||||
|
f.logger.Trace("forwarder: context done, dropping UDP packet")
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
id := r.ID()
|
||||||
|
|
||||||
|
f.udpForwarder.RLock()
|
||||||
|
_, exists := f.udpForwarder.conns[id]
|
||||||
|
f.udpForwarder.RUnlock()
|
||||||
|
if exists {
|
||||||
|
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
|
if err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||||
|
// TODO: Send ICMP error message
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create wait queue for blocking syscalls
|
||||||
|
wq := waiter.Queue{}
|
||||||
|
ep, epErr := r.CreateEndpoint(&wq)
|
||||||
|
if epErr != nil {
|
||||||
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
inConn := gonet.NewUDPConn(f.stack, &wq, ep)
|
||||||
|
connCtx, connCancel := context.WithCancel(f.ctx)
|
||||||
|
|
||||||
|
pConn := &udpPacketConn{
|
||||||
|
conn: inConn,
|
||||||
|
outConn: outConn,
|
||||||
|
cancel: connCancel,
|
||||||
|
ep: ep,
|
||||||
|
flowID: flowID,
|
||||||
|
}
|
||||||
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
// Double-check no connection was created while we were setting up
|
||||||
|
if _, exists := f.udpForwarder.conns[id]; exists {
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
pConn.cancel()
|
||||||
|
if err := inConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
|
}
|
||||||
|
if err := outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return
|
||||||
|
}
|
||||||
|
f.udpForwarder.conns[id] = pConn
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||||
|
|
||||||
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
defer func() {
|
||||||
|
pConn.cancel()
|
||||||
|
if err := pConn.conn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
|
}
|
||||||
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
ep.Close()
|
||||||
|
|
||||||
|
f.udpForwarder.Lock()
|
||||||
|
delete(f.udpForwarder.conns, id)
|
||||||
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||||
|
}()
|
||||||
|
|
||||||
|
errChan := make(chan error, 2)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.conn, pConn.outConn, &f.udpForwarder.bufPool, "outbound->inbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
errChan <- pConn.copy(ctx, pConn.outConn, pConn.conn, &f.udpForwarder.bufPool, "inbound->outbound")
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||||
|
return
|
||||||
|
case err := <-errChan:
|
||||||
|
if err != nil && !isClosedError(err) {
|
||||||
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
|
}
|
||||||
|
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) getIdleDuration() time.Duration {
|
||||||
|
lastSeen := time.Unix(0, c.lastSeen.Load())
|
||||||
|
return time.Since(lastSeen)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *udpPacketConn) copy(ctx context.Context, dst net.Conn, src net.Conn, bufPool *sync.Pool, direction string) error {
|
||||||
|
bufp := bufPool.Get().(*[]byte)
|
||||||
|
defer bufPool.Put(bufp)
|
||||||
|
buffer := *bufp
|
||||||
|
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return ctx.Err()
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := src.SetDeadline(time.Now().Add(udpTimeout)); err != nil {
|
||||||
|
return fmt.Errorf("set read deadline: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
n, err := src.Read(buffer)
|
||||||
|
if err != nil {
|
||||||
|
if isTimeout(err) {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
return fmt.Errorf("read from %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
_, err = dst.Write(buffer[:n])
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("write to %s: %w", direction, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
c.updateLastSeen()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func isClosedError(err error) bool {
|
||||||
|
return errors.Is(err, net.ErrClosed) || errors.Is(err, context.Canceled)
|
||||||
|
}
|
||||||
|
|
||||||
|
func isTimeout(err error) bool {
|
||||||
|
var netErr net.Error
|
||||||
|
if errors.As(err, &netErr) {
|
||||||
|
return netErr.Timeout()
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
131
client/firewall/uspfilter/localip.go
Normal file
131
client/firewall/uspfilter/localip.go
Normal file
@@ -0,0 +1,131 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
|
)
|
||||||
|
|
||||||
|
type localIPManager struct {
|
||||||
|
mu sync.RWMutex
|
||||||
|
|
||||||
|
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
||||||
|
ipv4Bitmap [1 << 16]uint32
|
||||||
|
}
|
||||||
|
|
||||||
|
func newLocalIPManager() *localIPManager {
|
||||||
|
return &localIPManager{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) setBitmapBit(ip net.IP) {
|
||||||
|
ipv4 := ip.To4()
|
||||||
|
if ipv4 == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
|
high := (uint16(ip[0]) << 8) | uint16(ip[1])
|
||||||
|
low := (uint16(ip[2]) << 8) | uint16(ip[3])
|
||||||
|
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
|
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
||||||
|
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
||||||
|
if int(high) >= len(*newIPv4Bitmap) {
|
||||||
|
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
||||||
|
}
|
||||||
|
ipStr := ip.String()
|
||||||
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
|
ipv4Set[ipStr] = struct{}{}
|
||||||
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
|
newIPv4Bitmap[high] |= 1 << (low % 32)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
if err != nil {
|
||||||
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
|
log.Debugf("process IP failed: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
err = fmt.Errorf("panic: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
var newIPv4Bitmap [1 << 16]uint32
|
||||||
|
ipv4Set := make(map[string]struct{})
|
||||||
|
var ipv4Addresses []string
|
||||||
|
|
||||||
|
// 127.0.0.0/8
|
||||||
|
high := uint16(127) << 8
|
||||||
|
for i := uint16(0); i < 256; i++ {
|
||||||
|
newIPv4Bitmap[high|i] = 0xffffffff
|
||||||
|
}
|
||||||
|
|
||||||
|
if iface != nil {
|
||||||
|
if err := m.processIP(iface.Address().IP, &newIPv4Bitmap, ipv4Set, &ipv4Addresses); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to get interfaces: %v", err)
|
||||||
|
} else {
|
||||||
|
for _, intf := range interfaces {
|
||||||
|
m.processInterface(intf, &newIPv4Bitmap, ipv4Set, &ipv4Addresses)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.ipv4Bitmap = newIPv4Bitmap
|
||||||
|
m.mu.Unlock()
|
||||||
|
|
||||||
|
log.Debugf("Local IPv4 addresses: %v", ipv4Addresses)
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
|
m.mu.RLock()
|
||||||
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
|
if ip.Is4() {
|
||||||
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
|
}
|
||||||
271
client/firewall/uspfilter/localip_test.go
Normal file
271
client/firewall/uspfilter/localip_test.go
Normal file
@@ -0,0 +1,271 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestLocalIPManager(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
setupAddr wgaddr.Address
|
||||||
|
testIP netip.Addr
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Localhost range",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost standard address",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Localhost range edge",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP matches",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
|
expected: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP doesn't match",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "IPv6 address",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("fe80::1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("fe80::"),
|
||||||
|
Mask: net.CIDRMask(64, 128),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
|
||||||
|
mock := &IFaceMock{
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return tt.setupAddr
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
err := manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result := manager.IsLocalIP(tt.testIP)
|
||||||
|
require.Equal(t, tt.expected, result)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
||||||
|
manager := newLocalIPManager()
|
||||||
|
mock := &IFaceMock{}
|
||||||
|
|
||||||
|
// Get actual local interfaces
|
||||||
|
interfaces, err := net.Interfaces()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
var tests []struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add all local interface IPs to test cases
|
||||||
|
for _, iface := range interfaces {
|
||||||
|
addrs, err := iface.Addrs()
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
for _, addr := range addrs {
|
||||||
|
var ip net.IP
|
||||||
|
switch v := addr.(type) {
|
||||||
|
case *net.IPNet:
|
||||||
|
ip = v.IP
|
||||||
|
case *net.IPAddr:
|
||||||
|
ip = v.IP
|
||||||
|
default:
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
|
if ip4 := ip.To4(); ip4 != nil {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip4.String(),
|
||||||
|
expected: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Add some external IPs as negative test cases
|
||||||
|
externalIPs := []string{
|
||||||
|
"8.8.8.8",
|
||||||
|
"1.1.1.1",
|
||||||
|
"208.67.222.222",
|
||||||
|
}
|
||||||
|
for _, ip := range externalIPs {
|
||||||
|
tests = append(tests, struct {
|
||||||
|
ip string
|
||||||
|
expected bool
|
||||||
|
}{
|
||||||
|
ip: ip,
|
||||||
|
expected: false,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
require.NotEmpty(t, tests, "No test cases generated")
|
||||||
|
|
||||||
|
err = manager.UpdateLocalIPs(mock)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
|
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||||
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// MapImplementation is a version using map[string]struct{}
|
||||||
|
type MapImplementation struct {
|
||||||
|
localIPs map[string]struct{}
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkIPChecks(b *testing.B) {
|
||||||
|
interfaces := make([]net.IP, 16)
|
||||||
|
for i := range interfaces {
|
||||||
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup bitmap version
|
||||||
|
bitmapManager := &localIPManager{
|
||||||
|
ipv4Bitmap: [1 << 16]uint32{},
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
|
bitmapManager.setBitmapBit(ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Setup map version
|
||||||
|
mapManager := &MapImplementation{
|
||||||
|
localIPs: make(map[string]struct{}),
|
||||||
|
}
|
||||||
|
for _, ip := range interfaces[:8] {
|
||||||
|
mapManager.localIPs[ip.String()] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
b.Run("Bitmap_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Bitmap_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bitmapManager.checkBitmapBit(ip)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Hit", func(b *testing.B) {
|
||||||
|
ip := interfaces[4]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("Map_Miss", func(b *testing.B) {
|
||||||
|
ip := interfaces[12]
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
// nolint:gosimple
|
||||||
|
_, _ = mapManager.localIPs[ip.String()]
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkWGPosition(b *testing.B) {
|
||||||
|
wgIP := net.ParseIP("10.10.0.1")
|
||||||
|
|
||||||
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
bm.setBitmapBit(wgIP)
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
|
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
||||||
|
// Fill with other IPs first
|
||||||
|
for i := 0; i < 15; i++ {
|
||||||
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
}
|
||||||
|
bm.setBitmapBit(wgIP) // Add WG IP last
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
bm.checkBitmapBit(wgIP)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
252
client/firewall/uspfilter/log/log.go
Normal file
252
client/firewall/uspfilter/log/log.go
Normal file
@@ -0,0 +1,252 @@
|
|||||||
|
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||||
|
package log
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"fmt"
|
||||||
|
"io"
|
||||||
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
maxBatchSize = 1024 * 16
|
||||||
|
maxMessageSize = 1024 * 2
|
||||||
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
logChannelSize = 1000
|
||||||
|
)
|
||||||
|
|
||||||
|
type Level uint32
|
||||||
|
|
||||||
|
const (
|
||||||
|
LevelPanic Level = iota
|
||||||
|
LevelFatal
|
||||||
|
LevelError
|
||||||
|
LevelWarn
|
||||||
|
LevelInfo
|
||||||
|
LevelDebug
|
||||||
|
LevelTrace
|
||||||
|
)
|
||||||
|
|
||||||
|
var levelStrings = map[Level]string{
|
||||||
|
LevelPanic: "PANC",
|
||||||
|
LevelFatal: "FATL",
|
||||||
|
LevelError: "ERRO",
|
||||||
|
LevelWarn: "WARN",
|
||||||
|
LevelInfo: "INFO",
|
||||||
|
LevelDebug: "DEBG",
|
||||||
|
LevelTrace: "TRAC",
|
||||||
|
}
|
||||||
|
|
||||||
|
type logMessage struct {
|
||||||
|
level Level
|
||||||
|
format string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
|
||||||
|
// Logger is a high-performance, non-blocking logger
|
||||||
|
type Logger struct {
|
||||||
|
output io.Writer
|
||||||
|
level atomic.Uint32
|
||||||
|
msgChannel chan logMessage
|
||||||
|
shutdown chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
bufPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||||
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
|
l := &Logger{
|
||||||
|
output: logrusLogger.Out,
|
||||||
|
msgChannel: make(chan logMessage, logChannelSize),
|
||||||
|
shutdown: make(chan struct{}),
|
||||||
|
bufPool: sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
b := make([]byte, 0, maxMessageSize)
|
||||||
|
return &b
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
|
l.level.Store(uint32(logrusLevel))
|
||||||
|
level := levelStrings[Level(logrusLevel)]
|
||||||
|
log.Debugf("New uspfilter logger created with loglevel %v", level)
|
||||||
|
|
||||||
|
l.wg.Add(1)
|
||||||
|
go l.worker()
|
||||||
|
|
||||||
|
return l
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetLevel sets the logging level
|
||||||
|
func (l *Logger) SetLevel(level Level) {
|
||||||
|
l.level.Store(uint32(level))
|
||||||
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) log(level Level, format string, args ...any) {
|
||||||
|
select {
|
||||||
|
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||||
|
default:
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Error logs a message at error level
|
||||||
|
func (l *Logger) Error(format string, args ...any) {
|
||||||
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
|
l.log(LevelError, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Warn logs a message at warning level
|
||||||
|
func (l *Logger) Warn(format string, args ...any) {
|
||||||
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
|
l.log(LevelWarn, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Info logs a message at info level
|
||||||
|
func (l *Logger) Info(format string, args ...any) {
|
||||||
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
|
l.log(LevelInfo, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Debug logs a message at debug level
|
||||||
|
func (l *Logger) Debug(format string, args ...any) {
|
||||||
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
|
l.log(LevelDebug, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Trace logs a message at trace level
|
||||||
|
func (l *Logger) Trace(format string, args ...any) {
|
||||||
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
|
l.log(LevelTrace, format, args...)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
var msg string
|
||||||
|
if len(args) > 0 {
|
||||||
|
msg = fmt.Sprintf(format, args...)
|
||||||
|
} else {
|
||||||
|
msg = format
|
||||||
|
}
|
||||||
|
*buf = append(*buf, msg...)
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
|
||||||
|
if len(*buf) > maxMessageSize {
|
||||||
|
*buf = (*buf)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessage handles a single log message and adds it to the buffer
|
||||||
|
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
defer l.bufPool.Put(bufp)
|
||||||
|
|
||||||
|
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||||
|
|
||||||
|
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
*buffer = append(*buffer, *bufp...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushBuffer writes the accumulated buffer to output
|
||||||
|
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||||
|
if len(*buffer) > 0 {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatch processes as many messages as possible without blocking
|
||||||
|
func (l *Logger) processBatch(buffer *[]byte) {
|
||||||
|
for len(*buffer) < maxBatchSize {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||||
|
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(l.msgChannel) == 0 {
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is the main goroutine that processes log messages
|
||||||
|
func (l *Logger) worker() {
|
||||||
|
defer l.wg.Done()
|
||||||
|
|
||||||
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
|
defer ticker.Stop()
|
||||||
|
|
||||||
|
buffer := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-l.shutdown:
|
||||||
|
l.handleShutdown(&buffer)
|
||||||
|
return
|
||||||
|
case <-ticker.C:
|
||||||
|
l.flushBuffer(&buffer)
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, &buffer)
|
||||||
|
l.processBatch(&buffer)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Stop gracefully shuts down the logger
|
||||||
|
func (l *Logger) Stop(ctx context.Context) error {
|
||||||
|
done := make(chan struct{})
|
||||||
|
|
||||||
|
l.closeOnce.Do(func() {
|
||||||
|
close(l.shutdown)
|
||||||
|
})
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
l.wg.Wait()
|
||||||
|
close(done)
|
||||||
|
}()
|
||||||
|
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return ctx.Err()
|
||||||
|
case <-done:
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
}
|
||||||
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type discard struct{}
|
||||||
|
|
||||||
|
func (d *discard) Write(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger(b *testing.B) {
|
||||||
|
simpleMessage := "Connection established"
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4 // TCPStateEstablished
|
||||||
|
|
||||||
|
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||||
|
protocol := "TCP"
|
||||||
|
direction := "outbound"
|
||||||
|
flags := uint16(0x18) // ACK + PSH
|
||||||
|
sequence := uint32(123456789)
|
||||||
|
acknowledged := uint32(987654321)
|
||||||
|
payloadSize := 1460
|
||||||
|
fragmented := false
|
||||||
|
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||||
|
|
||||||
|
b.Run("SimpleMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(simpleMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ComplexMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||||
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||||
|
func BenchmarkLoggerBurst(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestLogger() *log.Logger {
|
||||||
|
logrusLogger := logrus.New()
|
||||||
|
logrusLogger.SetOutput(&discard{})
|
||||||
|
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||||
|
return log.NewFromLogrus(logrusLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupLogger(logger *log.Logger) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = logger.Stop(ctx)
|
||||||
|
}
|
||||||
@@ -1,30 +1,45 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Rule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type Rule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
mgmtId []byte
|
||||||
|
ip netip.Addr
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
direction firewall.RuleDirection
|
sPort *firewall.Port
|
||||||
sPort uint16
|
dPort *firewall.Port
|
||||||
dPort uint16
|
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetRuleID returns the rule id
|
// ID returns the rule id
|
||||||
func (r *Rule) GetRuleID() string {
|
func (r *PeerRule) ID() string {
|
||||||
|
return r.id
|
||||||
|
}
|
||||||
|
|
||||||
|
type RouteRule struct {
|
||||||
|
id string
|
||||||
|
mgmtId []byte
|
||||||
|
sources []netip.Prefix
|
||||||
|
destination netip.Prefix
|
||||||
|
proto firewall.Protocol
|
||||||
|
srcPort *firewall.Port
|
||||||
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
|
}
|
||||||
|
|
||||||
|
// ID returns the rule id
|
||||||
|
func (r *RouteRule) ID() string {
|
||||||
return r.id
|
return r.id
|
||||||
}
|
}
|
||||||
|
|||||||
411
client/firewall/uspfilter/tracer.go
Normal file
411
client/firewall/uspfilter/tracer.go
Normal file
@@ -0,0 +1,411 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/gopacket"
|
||||||
|
"github.com/google/gopacket/layers"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
)
|
||||||
|
|
||||||
|
type PacketStage int
|
||||||
|
|
||||||
|
const (
|
||||||
|
StageReceived PacketStage = iota
|
||||||
|
StageConntrack
|
||||||
|
StagePeerACL
|
||||||
|
StageRouting
|
||||||
|
StageRouteACL
|
||||||
|
StageForwarding
|
||||||
|
StageCompleted
|
||||||
|
)
|
||||||
|
|
||||||
|
const msgProcessingCompleted = "Processing completed"
|
||||||
|
|
||||||
|
func (s PacketStage) String() string {
|
||||||
|
return map[PacketStage]string{
|
||||||
|
StageReceived: "Received",
|
||||||
|
StageConntrack: "Connection Tracking",
|
||||||
|
StagePeerACL: "Peer ACL",
|
||||||
|
StageRouting: "Routing",
|
||||||
|
StageRouteACL: "Route ACL",
|
||||||
|
StageForwarding: "Forwarding",
|
||||||
|
StageCompleted: "Completed",
|
||||||
|
}[s]
|
||||||
|
}
|
||||||
|
|
||||||
|
type ForwarderAction struct {
|
||||||
|
Action string
|
||||||
|
RemoteAddr string
|
||||||
|
Error error
|
||||||
|
}
|
||||||
|
|
||||||
|
type TraceResult struct {
|
||||||
|
Timestamp time.Time
|
||||||
|
Stage PacketStage
|
||||||
|
Message string
|
||||||
|
Allowed bool
|
||||||
|
ForwarderAction *ForwarderAction
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketTrace struct {
|
||||||
|
SourceIP netip.Addr
|
||||||
|
DestinationIP netip.Addr
|
||||||
|
Protocol string
|
||||||
|
SourcePort uint16
|
||||||
|
DestinationPort uint16
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
Results []TraceResult
|
||||||
|
}
|
||||||
|
|
||||||
|
type TCPState struct {
|
||||||
|
SYN bool
|
||||||
|
ACK bool
|
||||||
|
FIN bool
|
||||||
|
RST bool
|
||||||
|
PSH bool
|
||||||
|
URG bool
|
||||||
|
}
|
||||||
|
|
||||||
|
type PacketBuilder struct {
|
||||||
|
SrcIP netip.Addr
|
||||||
|
DstIP netip.Addr
|
||||||
|
Protocol fw.Protocol
|
||||||
|
SrcPort uint16
|
||||||
|
DstPort uint16
|
||||||
|
ICMPType uint8
|
||||||
|
ICMPCode uint8
|
||||||
|
Direction fw.RuleDirection
|
||||||
|
PayloadSize int
|
||||||
|
TCPState *TCPState
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResult(stage PacketStage, message string, allowed bool) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *PacketTrace) AddResultWithForwarder(stage PacketStage, message string, allowed bool, action *ForwarderAction) {
|
||||||
|
t.Results = append(t.Results, TraceResult{
|
||||||
|
Timestamp: time.Now(),
|
||||||
|
Stage: stage,
|
||||||
|
Message: message,
|
||||||
|
Allowed: allowed,
|
||||||
|
ForwarderAction: action,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) Build() ([]byte, error) {
|
||||||
|
ip := p.buildIPLayer()
|
||||||
|
pktLayers := []gopacket.SerializableLayer{ip}
|
||||||
|
|
||||||
|
transportLayer, err := p.buildTransportLayer(ip)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
pktLayers = append(pktLayers, transportLayer...)
|
||||||
|
|
||||||
|
if p.PayloadSize > 0 {
|
||||||
|
payload := make([]byte, p.PayloadSize)
|
||||||
|
pktLayers = append(pktLayers, gopacket.Payload(payload))
|
||||||
|
}
|
||||||
|
|
||||||
|
return serializePacket(pktLayers)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
||||||
|
return &layers.IPv4{
|
||||||
|
Version: 4,
|
||||||
|
TTL: 64,
|
||||||
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
|
DstIP: p.DstIP.AsSlice(),
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTransportLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
switch p.Protocol {
|
||||||
|
case "tcp":
|
||||||
|
return p.buildTCPLayer(ip)
|
||||||
|
case "udp":
|
||||||
|
return p.buildUDPLayer(ip)
|
||||||
|
case "icmp":
|
||||||
|
return p.buildICMPLayer()
|
||||||
|
default:
|
||||||
|
return nil, fmt.Errorf("unsupported protocol: %s", p.Protocol)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildTCPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
tcp := &layers.TCP{
|
||||||
|
SrcPort: layers.TCPPort(p.SrcPort),
|
||||||
|
DstPort: layers.TCPPort(p.DstPort),
|
||||||
|
Window: 65535,
|
||||||
|
SYN: p.TCPState != nil && p.TCPState.SYN,
|
||||||
|
ACK: p.TCPState != nil && p.TCPState.ACK,
|
||||||
|
FIN: p.TCPState != nil && p.TCPState.FIN,
|
||||||
|
RST: p.TCPState != nil && p.TCPState.RST,
|
||||||
|
PSH: p.TCPState != nil && p.TCPState.PSH,
|
||||||
|
URG: p.TCPState != nil && p.TCPState.URG,
|
||||||
|
}
|
||||||
|
if err := tcp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for TCP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{tcp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildUDPLayer(ip *layers.IPv4) ([]gopacket.SerializableLayer, error) {
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(p.SrcPort),
|
||||||
|
DstPort: layers.UDPPort(p.DstPort),
|
||||||
|
}
|
||||||
|
if err := udp.SetNetworkLayerForChecksum(ip); err != nil {
|
||||||
|
return nil, fmt.Errorf("set network layer for UDP checksum: %w", err)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{udp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (p *PacketBuilder) buildICMPLayer() ([]gopacket.SerializableLayer, error) {
|
||||||
|
icmp := &layers.ICMPv4{
|
||||||
|
TypeCode: layers.CreateICMPv4TypeCode(p.ICMPType, p.ICMPCode),
|
||||||
|
}
|
||||||
|
if p.ICMPType == layers.ICMPv4TypeEchoRequest || p.ICMPType == layers.ICMPv4TypeEchoReply {
|
||||||
|
icmp.Id = uint16(1)
|
||||||
|
icmp.Seq = uint16(1)
|
||||||
|
}
|
||||||
|
return []gopacket.SerializableLayer{icmp}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func serializePacket(layers []gopacket.SerializableLayer) ([]byte, error) {
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
if err := gopacket.SerializeLayers(buf, opts, layers...); err != nil {
|
||||||
|
return nil, fmt.Errorf("serialize packet: %w", err)
|
||||||
|
}
|
||||||
|
return buf.Bytes(), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIPProtocolNumber(protocol fw.Protocol) int {
|
||||||
|
switch protocol {
|
||||||
|
case fw.ProtocolTCP:
|
||||||
|
return int(layers.IPProtocolTCP)
|
||||||
|
case fw.ProtocolUDP:
|
||||||
|
return int(layers.IPProtocolUDP)
|
||||||
|
case fw.ProtocolICMP:
|
||||||
|
return int(layers.IPProtocolICMPv4)
|
||||||
|
default:
|
||||||
|
return 0
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacketFromBuilder(builder *PacketBuilder) (*PacketTrace, error) {
|
||||||
|
packetData, err := builder.Build()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("build packet: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.TracePacket(packetData, builder.Direction), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *PacketTrace {
|
||||||
|
|
||||||
|
d := m.decoders.Get().(*decoder)
|
||||||
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
|
trace := &PacketTrace{Direction: direction}
|
||||||
|
|
||||||
|
// Initial packet decoding
|
||||||
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Failed to decode packet: %v", err), false)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
// Extract base packet info
|
||||||
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
|
trace.SourceIP = srcIP
|
||||||
|
trace.DestinationIP = dstIP
|
||||||
|
|
||||||
|
// Determine protocol and ports
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
trace.Protocol = "TCP"
|
||||||
|
trace.SourcePort = uint16(d.tcp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.tcp.DstPort)
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
trace.Protocol = "UDP"
|
||||||
|
trace.SourcePort = uint16(d.udp.SrcPort)
|
||||||
|
trace.DestinationPort = uint16(d.udp.DstPort)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
trace.Protocol = "ICMP"
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageReceived, fmt.Sprintf("Received %s packet: %s:%d -> %s:%d",
|
||||||
|
trace.Protocol, srcIP, trace.SourcePort, dstIP, trace.DestinationPort), true)
|
||||||
|
|
||||||
|
if direction == fw.RuleDirectionOUT {
|
||||||
|
return m.traceOutbound(packetData, trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||||
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if !m.handleRouting(trace) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
if m.nativeRouter.Load() {
|
||||||
|
return m.handleNativeRouter(trace)
|
||||||
|
}
|
||||||
|
|
||||||
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||||
|
msg := "No existing connection found"
|
||||||
|
if allowed {
|
||||||
|
msg = m.buildConntrackStateMessage(d)
|
||||||
|
trace.AddResult(StageConntrack, msg, true)
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed by connection tracking", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
trace.AddResult(StageConntrack, msg, false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
||||||
|
msg := "Matched existing connection state"
|
||||||
|
switch d.decoded[1] {
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
msg += fmt.Sprintf(" (TCP Flags: SYN=%v ACK=%v RST=%v FIN=%v)",
|
||||||
|
flags&conntrack.TCPSyn != 0,
|
||||||
|
flags&conntrack.TCPAck != 0,
|
||||||
|
flags&conntrack.TCPRst != 0,
|
||||||
|
flags&conntrack.TCPFin != 0)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
msg += fmt.Sprintf(" (ICMP ID=%d, Seq=%d)", d.icmp4.Id, d.icmp4.Seq)
|
||||||
|
}
|
||||||
|
return msg
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
|
|
||||||
|
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
strRuleId := "<no id>"
|
||||||
|
if ruleId != nil {
|
||||||
|
strRuleId = string(ruleId)
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||||
|
if blocked {
|
||||||
|
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||||
|
trace.AddResult(StagePeerACL, msg, false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StagePeerACL, msg, true)
|
||||||
|
|
||||||
|
// Handle netstack mode
|
||||||
|
if m.netstack {
|
||||||
|
switch {
|
||||||
|
case !m.localForwarding:
|
||||||
|
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||||
|
case m.forwarder.Load() != nil:
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
default:
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// In normal mode, packets are allowed through for local delivery
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
|
if !m.routingEnabled.Load() {
|
||||||
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouting, "Routing enabled, checking ACLs", true)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
||||||
|
trace.AddResult(StageRouteACL, "Using native router, skipping ACL checks", true)
|
||||||
|
trace.AddResult(StageForwarding, "Forwarding via native router", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||||
|
proto, _ := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
|
strId := string(id)
|
||||||
|
if id == nil {
|
||||||
|
strId = "<no id>"
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||||
|
if !allowed {
|
||||||
|
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||||
|
}
|
||||||
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
|
if allowed && m.forwarder.Load() != nil {
|
||||||
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
|
}
|
||||||
|
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, allowed)
|
||||||
|
return trace
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr string, allowed bool) {
|
||||||
|
fwdAction := &ForwarderAction{
|
||||||
|
Action: action,
|
||||||
|
RemoteAddr: remoteAddr,
|
||||||
|
}
|
||||||
|
trace.AddResultWithForwarder(StageForwarding,
|
||||||
|
fmt.Sprintf("Forwarding to %s", fwdAction.Action), allowed, fwdAction)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
|
// will create or update the connection state
|
||||||
|
dropped := m.processOutgoingHooks(packetData, 0)
|
||||||
|
if dropped {
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
|
} else {
|
||||||
|
trace.AddResult(StageCompleted, "Packet allowed (outgoing)", true)
|
||||||
|
}
|
||||||
|
return trace
|
||||||
|
}
|
||||||
440
client/firewall/uspfilter/tracer_test.go
Normal file
440
client/firewall/uspfilter/tracer_test.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||||
|
t.Logf("Trace results: %v", trace.Results)
|
||||||
|
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||||
|
for _, result := range trace.Results {
|
||||||
|
actualStages = append(actualStages, result.Stage)
|
||||||
|
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||||
|
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||||
|
lastResult := trace.Results[len(trace.Results)-1]
|
||||||
|
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||||
|
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTracePacket(t *testing.T) {
|
||||||
|
setupTracerTest := func(statefulMode bool) *Manager {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !statefulMode {
|
||||||
|
m.stateful = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
builder := &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocol == "tcp" {
|
||||||
|
builder.TCPState = &TCPState{SYN: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder
|
||||||
|
}
|
||||||
|
|
||||||
|
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
return &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: "icmp",
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*Manager)
|
||||||
|
packetBuilder func() *PacketBuilder
|
||||||
|
expectedStages []PacketStage
|
||||||
|
expectedAllow bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = true
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithoutForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_NativeRouter",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(true)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_RoutingDisabled",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ConnectionTracking_Hit",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||||
|
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||||
|
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||||
|
return pb
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OutboundTraffic",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPEchoRequest",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPDestinationUnreachable",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithoutHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolUDP
|
||||||
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
hookFunc := func([]byte) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "StatefulDisabled_NoTracking",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.stateful = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
m := setupTracerTest(true)
|
||||||
|
|
||||||
|
tc.setup(m)
|
||||||
|
|
||||||
|
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||||
|
"100.10.0.100 should be recognized as a local IP")
|
||||||
|
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||||
|
"172.17.0.2 should not be recognized as a local IP")
|
||||||
|
|
||||||
|
pb := tc.packetBuilder()
|
||||||
|
|
||||||
|
trace, err := m.TracePacketFromBuilder(pb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifyTraceStages(t, trace, tc.expectedStages)
|
||||||
|
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
1062
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
1062
client/firewall/uspfilter/uspfilter_bench_test.go
Normal file
File diff suppressed because it is too large
Load Diff
1017
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
1017
client/firewall/uspfilter/uspfilter_filter_test.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,23 +1,50 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
|
|
||||||
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(context.Background(), nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
AddressFunc func() iface.WGAddress
|
AddressFunc func() wgaddr.Address
|
||||||
|
GetWGDeviceFunc func() *wgdevice.Device
|
||||||
|
GetDeviceFunc func() *device.FilteredDevice
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetWGDevice() *wgdevice.Device {
|
||||||
|
if i.GetWGDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetWGDeviceFunc()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (i *IFaceMock) GetDevice() *device.FilteredDevice {
|
||||||
|
if i.GetDeviceFunc == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
return i.GetDeviceFunc()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
||||||
@@ -27,9 +54,9 @@ func (i *IFaceMock) SetFilter(iface device.PacketFilter) error {
|
|||||||
return i.SetFilterFunc(iface)
|
return i.SetFilterFunc(iface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (i *IFaceMock) Address() iface.WGAddress {
|
func (i *IFaceMock) Address() wgaddr.Address {
|
||||||
if i.AddressFunc == nil {
|
if i.AddressFunc == nil {
|
||||||
return iface.WGAddress{}
|
return wgaddr.Address{}
|
||||||
}
|
}
|
||||||
return i.AddressFunc()
|
return i.AddressFunc()
|
||||||
}
|
}
|
||||||
@@ -39,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -59,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -67,12 +94,10 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -94,48 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip = net.ParseIP("192.168.1.1")
|
|
||||||
proto = fw.ProtocolTCP
|
|
||||||
port = &fw.Port{Values: []int{80}}
|
|
||||||
direction = fw.RuleDirectionIN
|
|
||||||
action = fw.ActionDrop
|
|
||||||
comment = "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule {
|
|
||||||
err = m.DeletePeerRule(r)
|
|
||||||
if err != nil {
|
|
||||||
t.Errorf("failed to delete rule: %v", err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; !ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.GetRuleID()]; ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
in bool
|
in bool
|
||||||
expDir fw.RuleDirection
|
expDir fw.RuleDirection
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
dPort uint16
|
dPort uint16
|
||||||
hook func([]byte) bool
|
hook func([]byte) bool
|
||||||
expectedID string
|
expectedID string
|
||||||
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Outgoing UDP Packet Hook",
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
in: false,
|
in: false,
|
||||||
expDir: fw.RuleDirectionOUT,
|
expDir: fw.RuleDirectionOUT,
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
ip: netip.MustParseAddr("10.168.0.1"),
|
||||||
dPort: 8000,
|
dPort: 8000,
|
||||||
hook: func([]byte) bool { return true },
|
hook: func([]byte) bool { return true },
|
||||||
},
|
},
|
||||||
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Incoming UDP Packet Hook",
|
name: "Test Incoming UDP Packet Hook",
|
||||||
in: true,
|
in: true,
|
||||||
expDir: fw.RuleDirectionIN,
|
expDir: fw.RuleDirectionIN,
|
||||||
ip: net.IPv6loopback,
|
ip: netip.MustParseAddr("::1"),
|
||||||
dPort: 9000,
|
dPort: 9000,
|
||||||
hook: func([]byte) bool { return false },
|
hook: func([]byte) bool { return false },
|
||||||
},
|
},
|
||||||
@@ -185,20 +187,20 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager := &Manager{
|
manager, err := Create(&IFaceMock{
|
||||||
incomingRules: map[string]RuleSet{},
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
outgoingRules: map[string]RuleSet{},
|
}, false, flowLogger)
|
||||||
}
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule Rule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
for _, rule := range manager.incomingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -206,27 +208,23 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.dPort != addedRule.dPort {
|
if tt.dPort != addedRule.dPort.Values[0] {
|
||||||
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort)
|
t.Errorf("expected dPort %d, got %d", tt.dPort, addedRule.dPort.Values[0])
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if layers.LayerTypeUDP != addedRule.protoLayer {
|
if layers.LayerTypeUDP != addedRule.protoLayer {
|
||||||
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
t.Errorf("expected protoLayer %s, got %s", layers.LayerTypeUDP, addedRule.protoLayer)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if tt.expDir != addedRule.direction {
|
|
||||||
t.Errorf("expected direction %d, got %d", tt.expDir, addedRule.direction)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
if addedRule.udpHook == nil {
|
if addedRule.udpHook == nil {
|
||||||
t.Errorf("expected udpHook to be set")
|
t.Errorf("expected udpHook to be set")
|
||||||
return
|
return
|
||||||
@@ -240,7 +238,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -248,18 +246,16 @@ func TestManagerReset(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := net.ParseIP("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []int{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
err = m.Reset(nil)
|
err = m.Close(nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -273,9 +269,18 @@ func TestManagerReset(t *testing.T) {
|
|||||||
func TestNotMatchByIP(t *testing.T) {
|
func TestNotMatchByIP(t *testing.T) {
|
||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -287,11 +292,9 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
|
|
||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
direction := fw.RuleDirectionOUT
|
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, direction, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -313,7 +316,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
t.Errorf("failed to set network layer for checksum: %v", err)
|
t.Errorf("failed to set network layer for checksum: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
payload := gopacket.Payload([]byte("test"))
|
payload := gopacket.Payload("test")
|
||||||
|
|
||||||
buf := gopacket.NewSerializeBuffer()
|
buf := gopacket.NewSerializeBuffer()
|
||||||
opts := gopacket.SerializeOptions{
|
opts := gopacket.SerializeOptions{
|
||||||
@@ -325,12 +328,12 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes(), m.outgoingRules, false) {
|
if m.dropFilter(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = m.Reset(nil); err != nil {
|
if err = m.Close(nil); err != nil {
|
||||||
t.Errorf("failed to reset Manager: %v", err)
|
t.Errorf("failed to reset Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -344,14 +347,17 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface)
|
manager, err := Create(iface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
@@ -384,6 +390,88 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
manager.udpTracker.Close()
|
||||||
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
manager.decoders = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
return d
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
hookCalled := false
|
||||||
|
hookID := manager.AddUDPPacketHook(
|
||||||
|
false,
|
||||||
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
|
53,
|
||||||
|
func([]byte) bool {
|
||||||
|
hookCalled = true
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
)
|
||||||
|
require.NotEmpty(t, hookID)
|
||||||
|
|
||||||
|
// Create test UDP packet
|
||||||
|
ipv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: net.ParseIP("100.10.0.1"),
|
||||||
|
DstIP: net.ParseIP("100.10.0.100"),
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
udp := &layers.UDP{
|
||||||
|
SrcPort: 51334,
|
||||||
|
DstPort: 53,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = udp.SetNetworkLayerForChecksum(ipv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
payload := gopacket.Payload("test")
|
||||||
|
|
||||||
|
buf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipv4, udp, payload)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test hook gets called
|
||||||
|
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
|
require.True(t, result)
|
||||||
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
|
// Test non-UDP packet is ignored
|
||||||
|
ipv4.Protocol = layers.IPProtocolTCP
|
||||||
|
buf = gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
|
require.False(t, result)
|
||||||
|
}
|
||||||
|
|
||||||
func TestUSPFilterCreatePerformance(t *testing.T) {
|
func TestUSPFilterCreatePerformance(t *testing.T) {
|
||||||
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
for _, testMax := range []int{10, 20, 30, 40, 50, 60, 70, 80, 90, 100, 200, 300, 400, 500, 600, 700, 800, 900, 1000} {
|
||||||
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
t.Run(fmt.Sprintf("Testing %d rules", testMax), func(t *testing.T) {
|
||||||
@@ -391,12 +479,12 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := manager.Reset(nil); err != nil {
|
if err := manager.Close(nil); err != nil {
|
||||||
t.Errorf("clear the manager state: %v", err)
|
t.Errorf("clear the manager state: %v", err)
|
||||||
}
|
}
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
@@ -405,12 +493,8 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ip := net.ParseIP("10.20.0.100")
|
ip := net.ParseIP("10.20.0.100")
|
||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []int{1000 + i}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
if i%2 == 0 {
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionOUT, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
} else {
|
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.RuleDirectionIN, fw.ActionAccept, "", "accept HTTP traffic")
|
|
||||||
}
|
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -418,3 +502,213 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
|
manager, err := Create(&IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
manager.wgNetwork = &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
}
|
||||||
|
|
||||||
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
|
manager.decoders = sync.Pool{
|
||||||
|
New: func() any {
|
||||||
|
d := &decoder{
|
||||||
|
decoded: []gopacket.LayerType{},
|
||||||
|
}
|
||||||
|
d.parser = gopacket.NewDecodingLayerParser(
|
||||||
|
layers.LayerTypeIPv4,
|
||||||
|
&d.eth, &d.ip4, &d.ip6, &d.icmp4, &d.icmp6, &d.tcp, &d.udp,
|
||||||
|
)
|
||||||
|
d.parser.IgnoreUnsupported = true
|
||||||
|
return d
|
||||||
|
},
|
||||||
|
}
|
||||||
|
defer func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
}()
|
||||||
|
|
||||||
|
// Set up packet parameters
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||||
|
srcPort := uint16(51334)
|
||||||
|
dstPort := uint16(53)
|
||||||
|
|
||||||
|
// Create outbound packet
|
||||||
|
outboundIPv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: srcIP.AsSlice(),
|
||||||
|
DstIP: dstIP.AsSlice(),
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
outboundUDP := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(srcPort),
|
||||||
|
DstPort: layers.UDPPort(dstPort),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = outboundUDP.SetNetworkLayerForChecksum(outboundIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
outboundBuf := gopacket.NewSerializeBuffer()
|
||||||
|
opts := gopacket.SerializeOptions{
|
||||||
|
ComputeChecksums: true,
|
||||||
|
FixLengths: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = gopacket.SerializeLayers(outboundBuf, opts,
|
||||||
|
outboundIPv4,
|
||||||
|
outboundUDP,
|
||||||
|
gopacket.Payload("test"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Process outbound packet and verify connection tracking
|
||||||
|
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||||
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
|
// Verify connection was tracked
|
||||||
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
|
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||||
|
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||||
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
|
// Create valid inbound response packet
|
||||||
|
inboundIPv4 := &layers.IPv4{
|
||||||
|
TTL: 64,
|
||||||
|
Version: 4,
|
||||||
|
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||||
|
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||||
|
Protocol: layers.IPProtocolUDP,
|
||||||
|
}
|
||||||
|
inboundUDP := &layers.UDP{
|
||||||
|
SrcPort: layers.UDPPort(dstPort), // Original destination port is now source
|
||||||
|
DstPort: layers.UDPPort(srcPort), // Original source port is now destination
|
||||||
|
}
|
||||||
|
|
||||||
|
err = inboundUDP.SetNetworkLayerForChecksum(inboundIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
inboundBuf := gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(inboundBuf, opts,
|
||||||
|
inboundIPv4,
|
||||||
|
inboundUDP,
|
||||||
|
gopacket.Payload("response"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
// Test roundtrip response handling over time
|
||||||
|
checkPoints := []struct {
|
||||||
|
sleep time.Duration
|
||||||
|
shouldAllow bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
sleep: 0,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Immediate response should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sleep: 50 * time.Millisecond,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Response within timeout should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
sleep: 100 * time.Millisecond,
|
||||||
|
shouldAllow: true,
|
||||||
|
description: "Response at half timeout should be allowed",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
// tracker hasn't updated conn for 250ms -> greater than 200ms timeout
|
||||||
|
sleep: 250 * time.Millisecond,
|
||||||
|
shouldAllow: false,
|
||||||
|
description: "Response after timeout should be dropped",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, cp := range checkPoints {
|
||||||
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
|
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||||
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
|
// If the connection should still be valid, verify it exists
|
||||||
|
if cp.shouldAllow {
|
||||||
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
require.True(t, exists, "Connection should still exist during valid window")
|
||||||
|
require.True(t, time.Since(conn.GetLastSeen()) < manager.udpTracker.Timeout(),
|
||||||
|
"LastSeen should be updated for valid responses")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test invalid response packets (while connection is expired)
|
||||||
|
invalidCases := []struct {
|
||||||
|
name string
|
||||||
|
modifyFunc func(*layers.IPv4, *layers.UDP)
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "wrong source IP",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
ip.SrcIP = net.ParseIP("100.10.0.101")
|
||||||
|
},
|
||||||
|
description: "Response from wrong IP should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong destination IP",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
ip.DstIP = net.ParseIP("100.10.0.2")
|
||||||
|
},
|
||||||
|
description: "Response to wrong IP should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong source port",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
udp.SrcPort = 54
|
||||||
|
},
|
||||||
|
description: "Response from wrong port should be dropped",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong destination port",
|
||||||
|
modifyFunc: func(ip *layers.IPv4, udp *layers.UDP) {
|
||||||
|
udp.DstPort = 51335
|
||||||
|
},
|
||||||
|
description: "Response to wrong port should be dropped",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
// Create a new outbound connection for invalid tests
|
||||||
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||||
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
|
for _, tc := range invalidCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
testIPv4 := *inboundIPv4
|
||||||
|
testUDP := *inboundUDP
|
||||||
|
|
||||||
|
tc.modifyFunc(&testIPv4, &testUDP)
|
||||||
|
|
||||||
|
err = testUDP.SetNetworkLayerForChecksum(&testIPv4)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
testBuf := gopacket.NewSerializeBuffer()
|
||||||
|
err = gopacket.SerializeLayers(testBuf, opts,
|
||||||
|
&testIPv4,
|
||||||
|
&testUDP,
|
||||||
|
gopacket.Payload("response"),
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify the invalid packet is dropped
|
||||||
|
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||||
|
require.True(t, drop, tc.description)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
12
client/iface/bind/control_android.go
Normal file
12
client/iface/bind/control_android.go
Normal file
@@ -0,0 +1,12 @@
|
|||||||
|
package bind
|
||||||
|
|
||||||
|
import (
|
||||||
|
wireguard "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
// ControlFns is not thread safe and should only be modified during init.
|
||||||
|
*wireguard.ControlFns = append(*wireguard.ControlFns, nbnet.ControlProtectSocket)
|
||||||
|
}
|
||||||
@@ -5,7 +5,6 @@ import (
|
|||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
"strings"
|
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
@@ -14,6 +13,8 @@ import (
|
|||||||
"golang.org/x/net/ipv4"
|
"golang.org/x/net/ipv4"
|
||||||
"golang.org/x/net/ipv6"
|
"golang.org/x/net/ipv6"
|
||||||
wgConn "golang.zx2c4.com/wireguard/conn"
|
wgConn "golang.zx2c4.com/wireguard/conn"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RecvMessage struct {
|
type RecvMessage struct {
|
||||||
@@ -52,9 +53,10 @@ type ICEBind struct {
|
|||||||
|
|
||||||
muUDPMux sync.Mutex
|
muUDPMux sync.Mutex
|
||||||
udpMux *UniversalUDPMuxDefault
|
udpMux *UniversalUDPMuxDefault
|
||||||
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
func NewICEBind(transportNet transport.Net, filterFn FilterFn, address wgaddr.Address) *ICEBind {
|
||||||
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
b, _ := wgConn.NewStdNetBind().(*wgConn.StdNetBind)
|
||||||
ib := &ICEBind{
|
ib := &ICEBind{
|
||||||
StdNetBind: b,
|
StdNetBind: b,
|
||||||
@@ -64,6 +66,7 @@ func NewICEBind(transportNet transport.Net, filterFn FilterFn) *ICEBind {
|
|||||||
endpoints: make(map[netip.Addr]net.Conn),
|
endpoints: make(map[netip.Addr]net.Conn),
|
||||||
closedChan: make(chan struct{}),
|
closedChan: make(chan struct{}),
|
||||||
closed: true,
|
closed: true,
|
||||||
|
address: address,
|
||||||
}
|
}
|
||||||
|
|
||||||
rc := receiverCreator{
|
rc := receiverCreator{
|
||||||
@@ -108,35 +111,17 @@ func (s *ICEBind) GetICEMux() (*UniversalUDPMuxDefault, error) {
|
|||||||
return s.udpMux, nil
|
return s.udpMux, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *ICEBind) SetEndpoint(peerAddress *net.UDPAddr, conn net.Conn) (*net.UDPAddr, error) {
|
func (b *ICEBind) SetEndpoint(fakeIP netip.Addr, conn net.Conn) {
|
||||||
fakeUDPAddr, err := fakeAddress(peerAddress)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
// force IPv4
|
|
||||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
|
||||||
if !ok {
|
|
||||||
return nil, fmt.Errorf("failed to convert IP to netip.Addr")
|
|
||||||
}
|
|
||||||
|
|
||||||
b.endpointsMu.Lock()
|
b.endpointsMu.Lock()
|
||||||
b.endpoints[fakeAddr] = conn
|
b.endpoints[fakeIP] = conn
|
||||||
b.endpointsMu.Unlock()
|
b.endpointsMu.Unlock()
|
||||||
|
|
||||||
return fakeUDPAddr, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *ICEBind) RemoveEndpoint(fakeUDPAddr *net.UDPAddr) {
|
func (b *ICEBind) RemoveEndpoint(fakeIP netip.Addr) {
|
||||||
fakeAddr, ok := netip.AddrFromSlice(fakeUDPAddr.IP.To4())
|
|
||||||
if !ok {
|
|
||||||
log.Warnf("failed to convert IP to netip.Addr")
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
b.endpointsMu.Lock()
|
b.endpointsMu.Lock()
|
||||||
defer b.endpointsMu.Unlock()
|
defer b.endpointsMu.Unlock()
|
||||||
delete(b.endpoints, fakeAddr)
|
|
||||||
|
delete(b.endpoints, fakeIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
func (b *ICEBind) Send(bufs [][]byte, ep wgConn.Endpoint) error {
|
||||||
@@ -161,9 +146,10 @@ func (s *ICEBind) createIPv4ReceiverFn(pc *ipv4.PacketConn, conn *net.UDPConn, r
|
|||||||
|
|
||||||
s.udpMux = NewUniversalUDPMuxDefault(
|
s.udpMux = NewUniversalUDPMuxDefault(
|
||||||
UniversalUDPMuxParams{
|
UniversalUDPMuxParams{
|
||||||
UDPConn: conn,
|
UDPConn: conn,
|
||||||
Net: s.transportNet,
|
Net: s.transportNet,
|
||||||
FilterFn: s.filterFn,
|
FilterFn: s.filterFn,
|
||||||
|
WGAddress: s.address,
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
return func(bufs [][]byte, sizes []int, eps []wgConn.Endpoint) (n int, err error) {
|
||||||
@@ -275,21 +261,6 @@ func (c *ICEBind) receiveRelayed(buffs [][]byte, sizes []int, eps []wgConn.Endpo
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// fakeAddress returns a fake address that is used to as an identifier for the peer.
|
|
||||||
// The fake address is in the format of 127.1.x.x where x.x is the last two octets of the peer address.
|
|
||||||
func fakeAddress(peerAddress *net.UDPAddr) (*net.UDPAddr, error) {
|
|
||||||
octets := strings.Split(peerAddress.IP.String(), ".")
|
|
||||||
if len(octets) != 4 {
|
|
||||||
return nil, fmt.Errorf("invalid IP format")
|
|
||||||
}
|
|
||||||
|
|
||||||
newAddr := &net.UDPAddr{
|
|
||||||
IP: net.ParseIP(fmt.Sprintf("127.1.%s.%s", octets[2], octets[3])),
|
|
||||||
Port: peerAddress.Port,
|
|
||||||
}
|
|
||||||
return newAddr, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
func getMessages(msgsPool *sync.Pool) *[]ipv6.Message {
|
||||||
return msgsPool.Get().(*[]ipv6.Message)
|
return msgsPool.Get().(*[]ipv6.Message)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"slices"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
@@ -152,45 +153,7 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
||||||
}
|
}
|
||||||
|
|
||||||
var localAddrsForUnspecified []net.Addr
|
mux := &UDPMuxDefault{
|
||||||
if addr, ok := params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
|
||||||
params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", params.UDPConn.LocalAddr())
|
|
||||||
} else if ok && addr.IP.IsUnspecified() {
|
|
||||||
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
|
||||||
// it will break the applications that are already using unspecified UDP connection
|
|
||||||
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
|
||||||
params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
|
||||||
var networks []ice.NetworkType
|
|
||||||
switch {
|
|
||||||
case addr.IP.To4() != nil:
|
|
||||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
|
||||||
|
|
||||||
case addr.IP.To16() != nil:
|
|
||||||
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
|
||||||
|
|
||||||
default:
|
|
||||||
params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", params.UDPConn.LocalAddr())
|
|
||||||
}
|
|
||||||
if len(networks) > 0 {
|
|
||||||
if params.Net == nil {
|
|
||||||
var err error
|
|
||||||
if params.Net, err = stdnet.NewNet(); err != nil {
|
|
||||||
params.Logger.Errorf("failed to get create network: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
ips, err := localInterfaces(params.Net, params.InterfaceFilter, nil, networks, true)
|
|
||||||
if err == nil {
|
|
||||||
for _, ip := range ips {
|
|
||||||
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
|
||||||
}
|
|
||||||
} else {
|
|
||||||
params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
return &UDPMuxDefault{
|
|
||||||
addressMap: map[string][]*udpMuxedConn{},
|
addressMap: map[string][]*udpMuxedConn{},
|
||||||
params: params,
|
params: params,
|
||||||
connsIPv4: make(map[string]*udpMuxedConn),
|
connsIPv4: make(map[string]*udpMuxedConn),
|
||||||
@@ -202,8 +165,55 @@ func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
|||||||
return newBufferHolder(receiveMTU + maxAddrSize)
|
return newBufferHolder(receiveMTU + maxAddrSize)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
localAddrsForUnspecified: localAddrsForUnspecified,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
mux.updateLocalAddresses()
|
||||||
|
return mux
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *UDPMuxDefault) updateLocalAddresses() {
|
||||||
|
var localAddrsForUnspecified []net.Addr
|
||||||
|
if addr, ok := m.params.UDPConn.LocalAddr().(*net.UDPAddr); !ok {
|
||||||
|
m.params.Logger.Errorf("LocalAddr is not a net.UDPAddr, got %T", m.params.UDPConn.LocalAddr())
|
||||||
|
} else if ok && addr.IP.IsUnspecified() {
|
||||||
|
// For unspecified addresses, the correct behavior is to return errListenUnspecified, but
|
||||||
|
// it will break the applications that are already using unspecified UDP connection
|
||||||
|
// with UDPMuxDefault, so print a warn log and create a local address list for mux.
|
||||||
|
m.params.Logger.Warn("UDPMuxDefault should not listening on unspecified address, use NewMultiUDPMuxFromPort instead")
|
||||||
|
var networks []ice.NetworkType
|
||||||
|
switch {
|
||||||
|
|
||||||
|
case addr.IP.To16() != nil:
|
||||||
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4, ice.NetworkTypeUDP6}
|
||||||
|
|
||||||
|
case addr.IP.To4() != nil:
|
||||||
|
networks = []ice.NetworkType{ice.NetworkTypeUDP4}
|
||||||
|
|
||||||
|
default:
|
||||||
|
m.params.Logger.Errorf("LocalAddr expected IPV4 or IPV6, got %T", m.params.UDPConn.LocalAddr())
|
||||||
|
}
|
||||||
|
if len(networks) > 0 {
|
||||||
|
if m.params.Net == nil {
|
||||||
|
var err error
|
||||||
|
if m.params.Net, err = stdnet.NewNet(); err != nil {
|
||||||
|
m.params.Logger.Errorf("failed to get create network: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ips, err := localInterfaces(m.params.Net, m.params.InterfaceFilter, nil, networks, true)
|
||||||
|
if err == nil {
|
||||||
|
for _, ip := range ips {
|
||||||
|
localAddrsForUnspecified = append(localAddrsForUnspecified, &net.UDPAddr{IP: ip, Port: addr.Port})
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
m.params.Logger.Errorf("failed to get local interfaces for unspecified addr: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
m.localAddrsForUnspecified = localAddrsForUnspecified
|
||||||
|
m.mu.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// LocalAddr returns the listening address of this UDPMuxDefault
|
// LocalAddr returns the listening address of this UDPMuxDefault
|
||||||
@@ -213,8 +223,12 @@ func (m *UDPMuxDefault) LocalAddr() net.Addr {
|
|||||||
|
|
||||||
// GetListenAddresses returns the list of addresses that this mux is listening on
|
// GetListenAddresses returns the list of addresses that this mux is listening on
|
||||||
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
||||||
|
m.updateLocalAddresses()
|
||||||
|
|
||||||
|
m.mu.Lock()
|
||||||
|
defer m.mu.Unlock()
|
||||||
if len(m.localAddrsForUnspecified) > 0 {
|
if len(m.localAddrsForUnspecified) > 0 {
|
||||||
return m.localAddrsForUnspecified
|
return slices.Clone(m.localAddrsForUnspecified)
|
||||||
}
|
}
|
||||||
|
|
||||||
return []net.Addr{m.LocalAddr()}
|
return []net.Addr{m.LocalAddr()}
|
||||||
@@ -224,7 +238,10 @@ func (m *UDPMuxDefault) GetListenAddresses() []net.Addr {
|
|||||||
// creates the connection if an existing one can't be found
|
// creates the connection if an existing one can't be found
|
||||||
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
func (m *UDPMuxDefault) GetConn(ufrag string, addr net.Addr) (net.PacketConn, error) {
|
||||||
// don't check addr for mux using unspecified address
|
// don't check addr for mux using unspecified address
|
||||||
if len(m.localAddrsForUnspecified) == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
m.mu.Lock()
|
||||||
|
lenLocalAddrs := len(m.localAddrsForUnspecified)
|
||||||
|
m.mu.Unlock()
|
||||||
|
if lenLocalAddrs == 0 && m.params.UDPConn.LocalAddr().String() != addr.String() {
|
||||||
return nil, fmt.Errorf("invalid address %s", addr.String())
|
return nil, fmt.Errorf("invalid address %s", addr.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -17,6 +17,8 @@ import (
|
|||||||
"github.com/pion/logging"
|
"github.com/pion/logging"
|
||||||
"github.com/pion/stun/v2"
|
"github.com/pion/stun/v2"
|
||||||
"github.com/pion/transport/v3"
|
"github.com/pion/transport/v3"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
)
|
)
|
||||||
|
|
||||||
// FilterFn is a function that filters out candidates based on the address.
|
// FilterFn is a function that filters out candidates based on the address.
|
||||||
@@ -41,6 +43,7 @@ type UniversalUDPMuxParams struct {
|
|||||||
XORMappedAddrCacheTTL time.Duration
|
XORMappedAddrCacheTTL time.Duration
|
||||||
Net transport.Net
|
Net transport.Net
|
||||||
FilterFn FilterFn
|
FilterFn FilterFn
|
||||||
|
WGAddress wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||||
@@ -64,6 +67,7 @@ func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDef
|
|||||||
mux: m,
|
mux: m,
|
||||||
logger: params.Logger,
|
logger: params.Logger,
|
||||||
filterFn: params.FilterFn,
|
filterFn: params.FilterFn,
|
||||||
|
address: params.WGAddress,
|
||||||
}
|
}
|
||||||
|
|
||||||
// embed UDPMux
|
// embed UDPMux
|
||||||
@@ -118,6 +122,7 @@ type udpConn struct {
|
|||||||
filterFn FilterFn
|
filterFn FilterFn
|
||||||
// TODO: reset cache on route changes
|
// TODO: reset cache on route changes
|
||||||
addrCache sync.Map
|
addrCache sync.Map
|
||||||
|
address wgaddr.Address
|
||||||
}
|
}
|
||||||
|
|
||||||
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
func (u *udpConn) WriteTo(b []byte, addr net.Addr) (int, error) {
|
||||||
@@ -159,6 +164,11 @@ func (u *udpConn) performFilterCheck(addr net.Addr) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if u.address.Network.Contains(a.AsSlice()) {
|
||||||
|
log.Warnf("Address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
|
return fmt.Errorf("address %s is part of the NetBird network %s, refusing to write", addr, u.address)
|
||||||
|
}
|
||||||
|
|
||||||
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
if isRouted, prefix, err := u.filterFn(a); err != nil {
|
||||||
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
log.Errorf("Failed to check if address %s is routed: %v", addr, err)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
@@ -43,13 +43,7 @@ func (c *KernelConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
// parse allowed ips
|
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIps)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -58,7 +52,7 @@ func (c *KernelConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAli
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
AllowedIPs: allowedIps,
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
// WgInterfaceDefault is a default interface name of Netbird
|
||||||
const WgInterfaceDefault = "wt0"
|
const WgInterfaceDefault = "wt0"
|
||||||
|
|||||||
@@ -2,5 +2,5 @@
|
|||||||
|
|
||||||
package configurer
|
package configurer
|
||||||
|
|
||||||
// WgInterfaceDefault is a default interface name of Wiretrustee
|
// WgInterfaceDefault is a default interface name of Netbird
|
||||||
const WgInterfaceDefault = "utun100"
|
const WgInterfaceDefault = "utun100"
|
||||||
|
|||||||
@@ -52,13 +52,7 @@ func (c *WGUSPConfigurer) ConfigureInterface(privateKey string, port int) error
|
|||||||
return c.device.IpcSet(toWgUserspaceString(config))
|
return c.device.IpcSet(toWgUserspaceString(config))
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps []net.IPNet, keepAlive time.Duration, endpoint *net.UDPAddr, preSharedKey *wgtypes.Key) error {
|
||||||
// parse allowed ips
|
|
||||||
_, ipNet, err := net.ParseCIDR(allowedIps)
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
peerKeyParsed, err := wgtypes.ParseKey(peerKey)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
@@ -67,7 +61,7 @@ func (c *WGUSPConfigurer) UpdatePeer(peerKey string, allowedIps string, keepAliv
|
|||||||
PublicKey: peerKeyParsed,
|
PublicKey: peerKeyParsed,
|
||||||
ReplaceAllowedIPs: false,
|
ReplaceAllowedIPs: false,
|
||||||
// don't replace allowed ips, wg will handle duplicated peer IP
|
// don't replace allowed ips, wg will handle duplicated peer IP
|
||||||
AllowedIPs: []net.IPNet{*ipNet},
|
AllowedIPs: allowedIps,
|
||||||
PersistentKeepaliveInterval: &keepAlive,
|
PersistentKeepaliveInterval: &keepAlive,
|
||||||
PresharedKey: preSharedKey,
|
PresharedKey: preSharedKey,
|
||||||
Endpoint: endpoint,
|
Endpoint: endpoint,
|
||||||
@@ -362,7 +356,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if runtime.GOOS == "linux" && !nbnet.CustomRoutingDisabled() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.NetbirdFwmark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user