mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 00:36:38 +00:00
Compare commits
98 Commits
test/netwo
...
v0.43.0
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
3cf87b6846 | ||
|
|
4fe4c2054d | ||
|
|
38ada44a0e | ||
|
|
dbf81a145e | ||
|
|
39483f8ca8 | ||
|
|
c0eaea938e | ||
|
|
ef8b8a2891 | ||
|
|
2817f62c13 | ||
|
|
4a9049566a | ||
|
|
85f92f8321 | ||
|
|
714beb6e3b | ||
|
|
400b9fca32 | ||
|
|
4013298e22 | ||
|
|
312bfd9bd7 | ||
|
|
8db05838ca | ||
|
|
c69df13515 | ||
|
|
986eb8c1e0 | ||
|
|
197761ba4d | ||
|
|
f74ea64c7b | ||
|
|
3b7b9d25bc | ||
|
|
1a6d6b3109 | ||
|
|
f686615876 | ||
|
|
a4311f574d | ||
|
|
0bb8eae903 | ||
|
|
e0b33d325d | ||
|
|
c38e07d89a | ||
|
|
a37368fff4 | ||
|
|
0c93bd3d06 | ||
|
|
a675531b5c | ||
|
|
7cb366bc7d | ||
|
|
a354004564 | ||
|
|
75bdd47dfb | ||
|
|
b165f63327 | ||
|
|
51bb52cdf5 | ||
|
|
4134b857b4 | ||
|
|
7839d2c169 | ||
|
|
b9f82e2f8a | ||
|
|
fd2a21c65d | ||
|
|
82d982b0ab | ||
|
|
9e24fe7701 | ||
|
|
e470701b80 | ||
|
|
e3ce026355 | ||
|
|
5ea2806663 | ||
|
|
d6b0673580 | ||
|
|
14913cfa7a | ||
|
|
03f600b576 | ||
|
|
192c97aa63 | ||
|
|
4db78db49a | ||
|
|
87e600a4f3 | ||
|
|
6162aeb82d | ||
|
|
1ba1e092ce | ||
|
|
86dbb4ee4f | ||
|
|
4af177215f | ||
|
|
df9c1b9883 | ||
|
|
5752bb78f2 | ||
|
|
fbd783ad58 | ||
|
|
80702b9323 | ||
|
|
09243a0fe0 | ||
|
|
3658215747 | ||
|
|
48ffec95dd | ||
|
|
cbec7bda80 | ||
|
|
21464ac770 | ||
|
|
ed5647028a | ||
|
|
29a6e5be71 | ||
|
|
6124e3b937 | ||
|
|
50f5cc48cd | ||
|
|
101cce27f2 | ||
|
|
a4f04f5570 | ||
|
|
fceb3ca392 | ||
|
|
34d86c5ab8 | ||
|
|
9cbcf7531f | ||
|
|
bd8f0c1ef3 | ||
|
|
051a5a4adc | ||
|
|
8b4c0c58e4 | ||
|
|
99b41543b8 | ||
|
|
2bbe0f3f09 | ||
|
|
9325fb7990 | ||
|
|
f081435a56 | ||
|
|
b62a1b56ce | ||
|
|
8d7c92c661 | ||
|
|
d9d051cb1e | ||
|
|
cb318b7ef4 | ||
|
|
8f0aa8352a | ||
|
|
c02e236196 | ||
|
|
f51e0b59bd | ||
|
|
32ec42a667 | ||
|
|
9929daf6ce | ||
|
|
939419a0ea | ||
|
|
919fe94fd5 | ||
|
|
df71cb4690 | ||
|
|
4508c61728 | ||
|
|
0ef476b014 | ||
|
|
6f82e96d6a | ||
|
|
a2faae5d62 | ||
|
|
4a3cbcd38a | ||
|
|
c2980bc8cf | ||
|
|
67ae871ce4 | ||
|
|
39ff5e833a |
27
.git-branches.toml
Normal file
27
.git-branches.toml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# More info around this file at https://www.git-town.com/configuration-file
|
||||||
|
|
||||||
|
[branches]
|
||||||
|
main = "main"
|
||||||
|
perennials = []
|
||||||
|
perennial-regex = ""
|
||||||
|
|
||||||
|
[create]
|
||||||
|
new-branch-type = "feature"
|
||||||
|
push-new-branches = false
|
||||||
|
|
||||||
|
[hosting]
|
||||||
|
dev-remote = "origin"
|
||||||
|
# platform = ""
|
||||||
|
# origin-hostname = ""
|
||||||
|
|
||||||
|
[ship]
|
||||||
|
delete-tracking-branch = false
|
||||||
|
strategy = "squash-merge"
|
||||||
|
|
||||||
|
[sync]
|
||||||
|
feature-strategy = "merge"
|
||||||
|
perennial-strategy = "rebase"
|
||||||
|
prototype-strategy = "merge"
|
||||||
|
push-hook = true
|
||||||
|
tags = true
|
||||||
|
upstream = false
|
||||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
## Issue ticket number and link
|
## Issue ticket number and link
|
||||||
|
|
||||||
|
## Stack
|
||||||
|
|
||||||
|
<!-- branch-stack -->
|
||||||
|
|
||||||
### Checklist
|
### Checklist
|
||||||
- [ ] Is it a bug fix
|
- [ ] Is it a bug fix
|
||||||
- [ ] Is a typo/documentation fix
|
- [ ] Is a typo/documentation fix
|
||||||
|
|||||||
21
.github/workflows/git-town.yml
vendored
Normal file
21
.github/workflows/git-town.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: Git Town
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- '**'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
git-town:
|
||||||
|
name: Display the branch stack
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: git-town/action@v1
|
||||||
|
with:
|
||||||
|
skip-single-stacks: true
|
||||||
10
.github/workflows/golang-test-freebsd.yml
vendored
10
.github/workflows/golang-test-freebsd.yml
vendored
@@ -22,14 +22,20 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
release: "14.1"
|
release: "14.2"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y go pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
|
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
|
||||||
|
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
|
||||||
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
|
curl -vLO "$GO_URL"
|
||||||
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|
||||||
# -x - to print all executed commands
|
# -x - to print all executed commands
|
||||||
# -e - to faile on first error
|
# -e - to faile on first error
|
||||||
run: |
|
run: |
|
||||||
set -e -x
|
set -e -x
|
||||||
|
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -timeout 1m -failfast ./base62/...
|
||||||
|
|||||||
222
.github/workflows/golang-test-linux.yml
vendored
222
.github/workflows/golang-test-linux.yml
vendored
@@ -146,6 +146,64 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
run: CGO_ENABLED=1 GOARCH=${{ matrix.arch }} CI=true go test -tags devcert -exec 'sudo' -timeout 10m -p 1 $(go list ./... | grep -v -e /management -e /signal -e /relay)
|
||||||
|
|
||||||
|
test_client_on_docker:
|
||||||
|
name: "Client (Docker) / Unit"
|
||||||
|
needs: [build-cache]
|
||||||
|
runs-on: ubuntu-22.04
|
||||||
|
steps:
|
||||||
|
- name: Install Go
|
||||||
|
uses: actions/setup-go@v5
|
||||||
|
with:
|
||||||
|
go-version: "1.23.x"
|
||||||
|
cache: false
|
||||||
|
|
||||||
|
- name: Checkout code
|
||||||
|
uses: actions/checkout@v4
|
||||||
|
|
||||||
|
- name: Get Go environment
|
||||||
|
id: go-env
|
||||||
|
run: |
|
||||||
|
echo "cache_dir=$(go env GOCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
echo "modcache_dir=$(go env GOMODCACHE)" >> $GITHUB_OUTPUT
|
||||||
|
|
||||||
|
- name: Cache Go modules
|
||||||
|
uses: actions/cache/restore@v4
|
||||||
|
id: cache-restore
|
||||||
|
with:
|
||||||
|
path: |
|
||||||
|
${{ steps.go-env.outputs.cache_dir }}
|
||||||
|
${{ steps.go-env.outputs.modcache_dir }}
|
||||||
|
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
||||||
|
restore-keys: |
|
||||||
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
|
- name: Run tests in container
|
||||||
|
env:
|
||||||
|
HOST_GOCACHE: ${{ steps.go-env.outputs.cache_dir }}
|
||||||
|
HOST_GOMODCACHE: ${{ steps.go-env.outputs.modcache_dir }}
|
||||||
|
run: |
|
||||||
|
CONTAINER_GOCACHE="/root/.cache/go-build"
|
||||||
|
CONTAINER_GOMODCACHE="/go/pkg/mod"
|
||||||
|
|
||||||
|
docker run --rm \
|
||||||
|
--cap-add=NET_ADMIN \
|
||||||
|
--privileged \
|
||||||
|
-v $PWD:/app \
|
||||||
|
-w /app \
|
||||||
|
-v "${HOST_GOCACHE}:${CONTAINER_GOCACHE}" \
|
||||||
|
-v "${HOST_GOMODCACHE}:${CONTAINER_GOMODCACHE}" \
|
||||||
|
-e CGO_ENABLED=1 \
|
||||||
|
-e CI=true \
|
||||||
|
-e GOARCH=${GOARCH_TARGET} \
|
||||||
|
-e GOCACHE=${CONTAINER_GOCACHE} \
|
||||||
|
-e GOMODCACHE=${CONTAINER_GOMODCACHE} \
|
||||||
|
golang:1.23-alpine \
|
||||||
|
sh -c ' \
|
||||||
|
apk update; apk add --no-cache \
|
||||||
|
ca-certificates iptables ip6tables dbus dbus-dev libpcap-dev build-base; \
|
||||||
|
go test -buildvcs=false -tags devcert -v -timeout 10m -p 1 $(go list -buildvcs=false ./... | grep -v -e /management -e /signal -e /relay -e /client/ui)
|
||||||
|
'
|
||||||
|
|
||||||
test_relay:
|
test_relay:
|
||||||
name: "Relay / Unit"
|
name: "Relay / Unit"
|
||||||
needs: [build-cache]
|
needs: [build-cache]
|
||||||
@@ -179,13 +237,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -232,13 +283,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -286,13 +330,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -314,6 +351,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=devcert \
|
go test -tags=devcert \
|
||||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
@@ -353,13 +391,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -380,7 +411,8 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags devcert -run=^$ -bench=. \
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./...
|
-timeout 20m ./...
|
||||||
@@ -396,6 +428,33 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Create Docker network
|
||||||
|
run: docker network create promnet
|
||||||
|
|
||||||
|
- name: Start Prometheus Pushgateway
|
||||||
|
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||||
|
|
||||||
|
- name: Start Prometheus (for Pushgateway forwarding)
|
||||||
|
run: |
|
||||||
|
echo '
|
||||||
|
global:
|
||||||
|
scrape_interval: 15s
|
||||||
|
scrape_configs:
|
||||||
|
- job_name: "pushgateway"
|
||||||
|
static_configs:
|
||||||
|
- targets: ["pushgateway:9091"]
|
||||||
|
remote_write:
|
||||||
|
- url: ${{ secrets.GRAFANA_URL }}
|
||||||
|
basic_auth:
|
||||||
|
username: ${{ secrets.GRAFANA_USER }}
|
||||||
|
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||||
|
' > prometheus.yml
|
||||||
|
|
||||||
|
docker run -d --name prometheus --network promnet \
|
||||||
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
|
-p 9090:9090 \
|
||||||
|
prom/prometheus
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
@@ -420,13 +479,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -447,11 +499,13 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
|
GIT_BRANCH=${{ github.ref_name }} \
|
||||||
go test -tags=benchmark \
|
go test -tags=benchmark \
|
||||||
-run=^$ \
|
-run=^$ \
|
||||||
-bench=. \
|
-bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
@@ -489,13 +543,6 @@ jobs:
|
|||||||
restore-keys: |
|
restore-keys: |
|
||||||
${{ runner.os }}-gotest-cache-
|
${{ runner.os }}-gotest-cache-
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install 32-bit libpcap
|
|
||||||
if: matrix.arch == '386'
|
|
||||||
run: sudo dpkg --add-architecture i386 && sudo apt update && sudo apt-get install -y libpcap0.8-dev:i386
|
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
@@ -505,89 +552,8 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=integration \
|
go test -tags=integration \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
test_client_on_docker:
|
|
||||||
name: "Client (Docker) / Unit"
|
|
||||||
needs: [ build-cache ]
|
|
||||||
runs-on: ubuntu-20.04
|
|
||||||
steps:
|
|
||||||
- name: Install Go
|
|
||||||
uses: actions/setup-go@v5
|
|
||||||
with:
|
|
||||||
go-version: "1.23.x"
|
|
||||||
cache: false
|
|
||||||
|
|
||||||
- name: Checkout code
|
|
||||||
uses: actions/checkout@v4
|
|
||||||
|
|
||||||
- name: Get Go environment
|
|
||||||
run: |
|
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
|
||||||
|
|
||||||
- name: Cache Go modules
|
|
||||||
uses: actions/cache/restore@v4
|
|
||||||
with:
|
|
||||||
path: |
|
|
||||||
${{ env.cache }}
|
|
||||||
${{ env.modcache }}
|
|
||||||
key: ${{ runner.os }}-gotest-cache-${{ hashFiles('**/go.sum') }}
|
|
||||||
restore-keys: |
|
|
||||||
${{ runner.os }}-gotest-cache-
|
|
||||||
|
|
||||||
- name: Install dependencies
|
|
||||||
run: sudo apt update && sudo apt install -y -q libgtk-3-dev libayatana-appindicator3-dev libgl1-mesa-dev xorg-dev gcc-multilib libpcap-dev
|
|
||||||
|
|
||||||
- name: Install modules
|
|
||||||
run: go mod tidy
|
|
||||||
|
|
||||||
- name: check git status
|
|
||||||
run: git --no-pager diff --exit-code
|
|
||||||
|
|
||||||
- name: Generate Shared Sock Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o sharedsock-testing.bin ./sharedsock
|
|
||||||
|
|
||||||
- name: Generate RouteManager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
|
||||||
|
|
||||||
- name: Generate SystemOps Test bin
|
|
||||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
|
||||||
|
|
||||||
- name: Generate Engine Test bin
|
|
||||||
run: CGO_ENABLED=1 go test -c -o engine-testing.bin ./client/internal
|
|
||||||
|
|
||||||
- name: Generate Peer Test bin
|
|
||||||
run: CGO_ENABLED=0 go test -c -o peer-testing.bin ./client/internal/peer/
|
|
||||||
|
|
||||||
- run: chmod +x *testing.bin
|
|
||||||
|
|
||||||
- name: Run Shared Sock tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Iface tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
|
||||||
|
|
||||||
- name: Run RouteManager tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run SystemOps tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run nftables Manager tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker with file store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Engine tests in docker with sqlite store
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -87,25 +87,25 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
- name: upload linux packages
|
- name: upload linux packages
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-packages
|
name: linux-packages
|
||||||
path: dist/netbird_linux**
|
path: dist/netbird_linux**
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
- name: upload windows packages
|
- name: upload windows packages
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-packages
|
name: windows-packages
|
||||||
path: dist/netbird_windows**
|
path: dist/netbird_windows**
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
- name: upload macos packages
|
- name: upload macos packages
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: macos-packages
|
name: macos-packages
|
||||||
path: dist/netbird_darwin**
|
path: dist/netbird_darwin**
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
|
|
||||||
release_ui:
|
release_ui:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ jobs:
|
|||||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||||
|
grep DisablePromptLogin management.json | grep 'true'
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|||||||
@@ -1,148 +1,64 @@
|
|||||||
# Contributor License Agreement
|
## Contributor License Agreement
|
||||||
|
|
||||||
We are incredibly thankful for the contributions we receive from the community.
|
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
||||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||||
as BSD-3 while allowing NetBird to build a sustainable business.
|
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||||
|
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||||
NetBird is committed to having a true Open Source Software ("OSS") license for
|
of the terms and conditions outlined below. The Contributor further represents that they are authorized to
|
||||||
our software. A CLA enables NetBird to safely commercialize our products
|
complete this process as described herein.
|
||||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
|
||||||
ability to use the project in their own projects or businesses, to republish modified
|
|
||||||
source, or to completely fork the project.
|
|
||||||
|
|
||||||
This page gives a human-friendly summary of our CLA, details on why we require a CLA, how
|
|
||||||
contributors can sign our CLA, and more. You may view the full legal CLA document (below).
|
|
||||||
|
|
||||||
# Human-friendly summary
|
|
||||||
|
|
||||||
This is a human-readable summary of (and not a substitute for) the full agreement (below).
|
|
||||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
|
||||||
carefully review all the terms of the actual CLA before agreeing.
|
|
||||||
|
|
||||||
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
|
||||||
in commercial products.
|
|
||||||
</li>
|
|
||||||
|
|
||||||
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
|
||||||
license to use that patent including within commercial products. You also agree that you
|
|
||||||
have permission to grant this license.
|
|
||||||
</li>
|
|
||||||
|
|
||||||
<li>No Warranty or Support Obligations.
|
|
||||||
By making a contribution, you are not obligating yourself to provide support for the
|
|
||||||
contribution, and you are not taking on any warranty obligations or providing any
|
|
||||||
assurances about how it will perform.
|
|
||||||
</li>
|
|
||||||
|
|
||||||
The CLA does not change the terms of the standard open source license used by our software
|
|
||||||
such as BSD-3 or MIT.
|
|
||||||
You are still free to use our projects within your own projects or businesses, republish
|
|
||||||
modified source, and more.
|
|
||||||
Please reference the appropriate license for the project you're contributing to to learn
|
|
||||||
more.
|
|
||||||
|
|
||||||
# Why require a CLA?
|
|
||||||
|
|
||||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
|
||||||
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
|
||||||
products.
|
|
||||||
|
|
||||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
|
||||||
adopt our projects. At the same time, the CLA ensures that all contributions to our open source projects are licensed
|
|
||||||
under the project's respective open source license, such as BSD-3.
|
|
||||||
|
|
||||||
Requiring a CLA is a common and well-accepted practice in open source. Major open source projects require CLAs such as
|
|
||||||
Apache Software Foundation projects, Facebook projects (such as React), Google projects (including Go), Python, Django,
|
|
||||||
and more. Each of these projects remains licensed under permissive OSS licenses such as MIT, Apache, BSD, and more.
|
|
||||||
|
|
||||||
# Signing the CLA
|
|
||||||
|
|
||||||
Open a pull request ("PR") to any of our open source projects to sign the CLA. A bot will comment on the PR asking you
|
|
||||||
to sign the CLA if you haven't already.
|
|
||||||
|
|
||||||
Follow the steps given by the bot to sign the CLA. This will require you to log in with GitHub (we only request public
|
|
||||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
|
||||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
|
||||||
|
|
||||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
|
||||||
require you to sign again.
|
|
||||||
|
|
||||||
# Legal Terms and Agreement
|
|
||||||
|
|
||||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
|
||||||
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
|
||||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
|
||||||
your own Contributions for any other purpose.
|
|
||||||
|
|
||||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
|
||||||
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
|
||||||
You reserve all right, title, and interest in and to Your Contributions.
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
```
|
|
||||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
|
||||||
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
|
||||||
entities that control, are controlled by, or are under common control with that entity are considered
|
|
||||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
|
||||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
|
||||||
percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
```
|
|
||||||
```
|
|
||||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
|
||||||
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
|
||||||
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
|
||||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
|
||||||
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
|
||||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
|
||||||
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
|
||||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
|
||||||
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
|
||||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
|
||||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
|
||||||
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
## 1 Preamble
|
||||||
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
In order to clarify the IP Rights situation with regard to Contributions from any person or entity, NetBird
|
||||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
must have a contributor license agreement on file to be signed by each Contributor, containing the license
|
||||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
terms below. This license serves as protection for both the Contributor as well as NetBird and its software users;
|
||||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
it does not change Contributor’s rights to use his/her own Contributions for any other purpose.
|
||||||
such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (
|
|
||||||
including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have
|
|
||||||
contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity
|
|
||||||
under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed.
|
|
||||||
|
|
||||||
|
## 2 Definitions
|
||||||
|
2.1 “IP Rights” shall mean all industrial and intellectual property rights, whether registered or not registered, whether created by Contributor or acquired by Contributor from third parties, and similar rights, including (but not limited to) semiconductor property rights, design rights, copyrights (including in the form of database rights and rights to software), all neighbouring rights (Leistungsschutzrechte), trademarks, service marks, titles, internet domain names, trade names and other labelling rights, rights deriving from corresponding applications and registrations of such rights as well as any licenses (Nutzungsrechte) under and entitlements to any such intellectual and industrial property rights.
|
||||||
|
|
||||||
4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to
|
2.2 "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is or previously has been intentionally Submitted by Contributor to NetBird for inclusion in, or documentation of any Work.
|
||||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
|
||||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
|
||||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
|
||||||
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
|
||||||
with NetBird.
|
|
||||||
|
|
||||||
|
2.3 "Contributor" shall mean the copyright owner or legal entity authorized by the copyright owner that is concluding this Agreement with NetBird. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
2.4 "Submitted" shall mean any form of electronic, verbal, or written communication sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, NetBird for the purpose of discussing and improving the Work, but excluding communication that is marked or otherwise designated in writing by Contributor as "Not a Contribution".
|
||||||
others). You represent that Your Contribution submissions include complete details of any third-party license or
|
|
||||||
other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware
|
|
||||||
and which are associated with any part of Your Contributions.
|
|
||||||
|
|
||||||
|
2.5 "Work" means any of the products owned or managed by NetBird, in particular, but not exclusively, software.
|
||||||
|
|
||||||
6. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
|
## 3 Licenses
|
||||||
You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in
|
3.1 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to reproduce by any means and in any form, in whole or in part, permanently or temporarily, the Contributions (including loading, displaying, executing, transmitting or storing works for the purpose of executing and processing data or transferring them to video, audio and other data carriers), including the right to distribute, display and present such Contributions and make them available to the public (e.g. via the internet) and to transmit and display such Contributions by any means. The license also includes the right to modify, translate, adapt, edit and otherwise alter the Contributions and to use these results in the same manner as the original Contributions and derivative works. Except for licenses in patents acc. to Sec. 3, such license refers to any IP Rights in the Contributions and derivative works. The Contributor acknowledges that NetBird is not required to credit them by name for their Contribution and agrees to waive any moral rights associated with their Contribution in relation to NetBird or its sublicensees.
|
||||||
writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
|
||||||
express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT,
|
|
||||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
|
|
||||||
|
3.2 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license in the Contributions to make, have made, use, sell, offer to sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by the Contributor which are necessarily infringed by Contributor‘s Contribution(s) alone or by combination of Contributor’s Contribution(s) with the Work to which such Contribution(s) was Submitted.
|
||||||
|
|
||||||
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
3.3 NetBird hereby accepts such licenses.
|
||||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
|
||||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
|
||||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
|
||||||
|
|
||||||
|
## 4 Contributor’s Representations
|
||||||
|
4.1 Contributor represents that Contributor is legally entitled to grant the above license. If Contributor’s employer has IP Rights to Contributor’s Contributions, Contributor represent that he/she has received permission to make Contributions on behalf of such employer, that such employer has waived such IP Rights to the Contributions of Contributor to NetBird, or that such employer has executed a separate contributor license agreement with NetBird.
|
||||||
|
|
||||||
|
4.2 Contributor represents that any Contribution is his/her original creation.
|
||||||
|
|
||||||
|
4.3 Contributor represents to his/her best knowledge that any Contribution does not violate any third party IP Rights.
|
||||||
|
|
||||||
|
4.4 Contributor represents that any Contribution submission includes complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which Contributor is personally aware and which are associated with any part of the Contribution.
|
||||||
|
|
||||||
|
4.5 The Contributor represents that their Contribution does not include any work distributed under a copyleft license.
|
||||||
|
|
||||||
|
## 5 Information obligation
|
||||||
|
Contributor agrees to notify NetBird of any facts or circumstances of which Contributor become aware that would make these representations inaccurate in any respect.
|
||||||
|
|
||||||
|
## 6 Submission of Third-Party works
|
||||||
|
Should Contributor wish to submit work that is not Contributor’s original creation, Contributor may submit it to NetBird separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which Contributor are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||||
|
|
||||||
|
## 7 No Consideration
|
||||||
|
Unless compensation is mandatory under statutory law, no compensation for any license under this agreement shall be payable.
|
||||||
|
|
||||||
|
## 8 Final Provisions
|
||||||
|
8.1 Laws. This Agreement is governed by the laws of the Federal Republic of Germany.
|
||||||
|
|
||||||
|
8.2 Venue. Place of jurisdiction shall, to the extent legally permissible, be Berlin, Germany.
|
||||||
|
|
||||||
|
8.3 Severability. If any provision in this agreement is unlawful, invalid or ineffective, it shall not affect the enforceability or effectiveness of the remainder of this agreement. The parties agree to replace any unlawful, invalid or ineffective provision with a provision that comes as close as possible to the commercial intent and purpose of the original provision. This section also applies accordingly to any gaps in the contract.
|
||||||
|
|
||||||
|
8.4 Variations. Any variations, amendments or supplements to this Agreement must be in writing. This also applies to any variation of this Section 8.4.
|
||||||
|
|
||||||
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
|
||||||
representations inaccurate in any respect.
|
|
||||||
|
|||||||
28
README.md
28
README.md
@@ -12,7 +12,7 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
@@ -29,13 +29,13 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||||
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
|
New: NetBird Kubernetes Operator
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -57,16 +57,16 @@
|
|||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
| Connectivity | Management | Security | Automation | Platforms |
|
| Connectivity | Management | Security | Automation| Platforms |
|
||||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
|----|----|----|----|----|
|
||||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||||
|
|
||||||
### Quickstart with NetBird Cloud
|
### Quickstart with NetBird Cloud
|
||||||
|
|
||||||
|
|||||||
@@ -1,5 +1,6 @@
|
|||||||
FROM alpine:3.21.3
|
FROM alpine:3.21.3
|
||||||
RUN apk add --no-cache ca-certificates iptables ip6tables
|
# iproute2: busybox doesn't display ip rules properly
|
||||||
|
RUN apk add --no-cache ca-certificates ip6tables iproute2 iptables
|
||||||
ENV NB_FOREGROUND_MODE=true
|
ENV NB_FOREGROUND_MODE=true
|
||||||
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
ENTRYPOINT [ "/usr/local/bin/netbird","up"]
|
||||||
COPY netbird /usr/local/bin/netbird
|
COPY netbird /usr/local/bin/netbird
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ type Anonymizer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
// 192.51.100.0, 100::
|
// 198.51.100.0, 100::
|
||||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,12 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@@ -326,3 +329,34 @@ func formatDuration(d time.Duration) string {
|
|||||||
s := d / time.Second
|
s := d / time.Second
|
||||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||||
|
var networkMap *mgmProto.NetworkMap
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if connectClient != nil {
|
||||||
|
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get latest network map: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
|
debug.GeneratorDependencies{
|
||||||
|
InternalConfig: config,
|
||||||
|
StatusRecorder: recorder,
|
||||||
|
NetworkMap: networkMap,
|
||||||
|
LogFile: logFilePath,
|
||||||
|
},
|
||||||
|
debug.BundleConfig{
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||||
|
}
|
||||||
|
|||||||
39
client/cmd/debug_unix.go
Normal file
39
client/cmd/debug_unix.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
usr1Ch := make(chan os.Signal, 1)
|
||||||
|
|
||||||
|
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-usr1Ch:
|
||||||
|
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
126
client/cmd/debug_windows.go
Normal file
126
client/cmd/debug_windows.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||||
|
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||||
|
|
||||||
|
waitTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||||
|
// Example usage with PowerShell:
|
||||||
|
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||||
|
// $evt.Set()
|
||||||
|
// $evt.Close()
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
env := os.Getenv(envListenEvent)
|
||||||
|
if env == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listenEvent, err := strconv.ParseBool(env)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !listenEvent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: restrict access by ACL
|
||||||
|
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||||
|
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||||
|
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||||
|
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||||
|
} else {
|
||||||
|
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventHandle == windows.InvalidHandle {
|
||||||
|
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||||
|
|
||||||
|
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
eventHandle windows.Handle,
|
||||||
|
) {
|
||||||
|
defer func() {
|
||||||
|
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case windows.WAIT_OBJECT_0:
|
||||||
|
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||||
|
|
||||||
|
// reset the event so it can be triggered again later (manual reset == 1)
|
||||||
|
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
case uint32(windows.WAIT_TIMEOUT):
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
}
|
||||||
|
|
||||||
var loginCmd = &cobra.Command{
|
var loginCmd = &cobra.Command{
|
||||||
Use: "login",
|
Use: "login",
|
||||||
Short: "login to the Netbird Management Service (first run)",
|
Short: "login to the Netbird Management Service (first run)",
|
||||||
@@ -51,6 +55,9 @@ var loginCmd = &cobra.Command{
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// update host's static platform and system information
|
||||||
|
system.UpdateStaticInfo()
|
||||||
|
|
||||||
ic := internal.ConfigInput{
|
ic := internal.ConfigInput{
|
||||||
ManagementURL: managementURL,
|
ManagementURL: managementURL,
|
||||||
AdminURL: adminURL,
|
AdminURL: adminURL,
|
||||||
@@ -127,7 +134,7 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -198,7 +205,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||||
@@ -212,19 +219,27 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
if noBrowser {
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
|
||||||
verificationURIComplete + " " + codeMsg)
|
} else {
|
||||||
|
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||||
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
|
verificationURIComplete + " " + codeMsg)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
if !noBrowser {
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
if err := open.Run(verificationURIComplete); err != nil {
|
||||||
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
|
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||||
|
|||||||
@@ -16,12 +16,17 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
|
"github.com/netbirdio/netbird/client/system"
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (p *program) Start(svc service.Service) error {
|
func (p *program) Start(svc service.Service) error {
|
||||||
// Start should not block. Do the actual work async.
|
// Start should not block. Do the actual work async.
|
||||||
log.Info("starting Netbird service") //nolint
|
log.Info("starting Netbird service") //nolint
|
||||||
|
|
||||||
|
// Collect static system and platform information
|
||||||
|
system.UpdateStaticInfo()
|
||||||
|
|
||||||
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
// in any case, even if configuration does not exists we run daemon to serve CLI gRPC API.
|
||||||
p.serv = grpc.NewServer()
|
p.serv = grpc.NewServer()
|
||||||
|
|
||||||
@@ -115,6 +120,7 @@ var runCmd = &cobra.Command{
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
|
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,14 +6,17 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
@@ -31,7 +34,7 @@ import (
|
|||||||
|
|
||||||
func startTestingServices(t *testing.T) string {
|
func startTestingServices(t *testing.T) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
config := &mgmt.Config{}
|
config := &types.Config{}
|
||||||
_, err := util.ReadJson("../testdata/management.json", config)
|
_, err := util.ReadJson("../testdata/management.json", config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -66,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
return s, lis
|
return s, lis
|
||||||
}
|
}
|
||||||
|
|
||||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", ":0")
|
lis, err := net.Listen("tcp", ":0")
|
||||||
@@ -89,14 +92,24 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock())
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,12 +32,16 @@ const (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
dnsLabelsFlag = "extra-dns-labels"
|
dnsLabelsFlag = "extra-dns-labels"
|
||||||
|
|
||||||
|
noBrowserFlag = "no-browser"
|
||||||
|
noBrowserDesc = "do not open the browser for SSO login"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
foregroundMode bool
|
foregroundMode bool
|
||||||
dnsLabels []string
|
dnsLabels []string
|
||||||
dnsLabelsValidated domain.List
|
dnsLabelsValidated domain.List
|
||||||
|
noBrowser bool
|
||||||
|
|
||||||
upCmd = &cobra.Command{
|
upCmd = &cobra.Command{
|
||||||
Use: "up",
|
Use: "up",
|
||||||
@@ -65,6 +69,9 @@ func init() {
|
|||||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||||
`or --extra-dns-labels ""`,
|
`or --extra-dns-labels ""`,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -212,6 +219,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil)
|
return connectClient.Run(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -349,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -10,17 +10,18 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
|
|||||||
@@ -96,36 +96,36 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
_ string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if !destination.Addr().Is4() {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -196,13 +196,13 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
@@ -242,6 +242,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
func getConntrackEstablished() []string {
|
func getConntrackEstablished() []string {
|
||||||
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
return []string{"-m", "conntrack", "--ctstate", "RELATED,ESTABLISHED", "-j", "ACCEPT"}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -97,7 +97,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -38,10 +38,12 @@ const (
|
|||||||
routingFinalForwardJump = "ACCEPT"
|
routingFinalForwardJump = "ACCEPT"
|
||||||
routingFinalNatJump = "MASQUERADE"
|
routingFinalNatJump = "MASQUERADE"
|
||||||
|
|
||||||
jumpManglePre = "jump-mangle-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNatPre = "jump-nat-pre"
|
jumpNatPre = "jump-nat-pre"
|
||||||
jumpNatPost = "jump-nat-post"
|
jumpNatPost = "jump-nat-post"
|
||||||
matchSet = "--match-set"
|
markManglePre = "mark-mangle-pre"
|
||||||
|
markManglePost = "mark-mangle-post"
|
||||||
|
matchSet = "--match-set"
|
||||||
|
|
||||||
dnatSuffix = "_dnat"
|
dnatSuffix = "_dnat"
|
||||||
snatSuffix = "_snat"
|
snatSuffix = "_snat"
|
||||||
@@ -55,18 +57,18 @@ type ruleInfo struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type routeFilteringRuleParams struct {
|
type routeFilteringRuleParams struct {
|
||||||
Sources []netip.Prefix
|
Source firewall.Network
|
||||||
Destination netip.Prefix
|
Destination firewall.Network
|
||||||
Proto firewall.Protocol
|
Proto firewall.Protocol
|
||||||
SPort *firewall.Port
|
SPort *firewall.Port
|
||||||
DPort *firewall.Port
|
DPort *firewall.Port
|
||||||
Direction firewall.RuleDirection
|
Direction firewall.RuleDirection
|
||||||
Action firewall.Action
|
Action firewall.Action
|
||||||
SetName string
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type routeRules map[string][]string
|
type routeRules map[string][]string
|
||||||
|
|
||||||
|
// the ipset library currently does not support comments, so we use the name only (string)
|
||||||
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
type ipsetCounter = refcounter.Counter[string, []netip.Prefix, struct{}]
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
@@ -115,45 +117,51 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
return fmt.Errorf("create containers: %w", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
r.updateState()
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
var setName string
|
var source firewall.Network
|
||||||
if len(sources) > 1 {
|
if len(sources) > 1 {
|
||||||
setName = firewall.GenerateSetName(sources)
|
source.Set = firewall.NewPrefixSet(sources)
|
||||||
if _, err := r.ipsetCounter.Increment(setName, sources); err != nil {
|
} else if len(sources) > 0 {
|
||||||
return nil, fmt.Errorf("create or get ipset: %w", err)
|
source.Prefix = sources[0]
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
params := routeFilteringRuleParams{
|
params := routeFilteringRuleParams{
|
||||||
Sources: sources,
|
Source: source,
|
||||||
Destination: destination,
|
Destination: destination,
|
||||||
Proto: proto,
|
Proto: proto,
|
||||||
SPort: sPort,
|
SPort: sPort,
|
||||||
DPort: dPort,
|
DPort: dPort,
|
||||||
Action: action,
|
Action: action,
|
||||||
SetName: setName,
|
|
||||||
}
|
}
|
||||||
|
|
||||||
rule := genRouteFilteringRuleSpec(params)
|
rule, err := r.genRouteRuleSpec(params, sources)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("generate route rule spec: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
// Insert DROP rules at the beginning, append ACCEPT rules at the end
|
||||||
var err error
|
|
||||||
if action == firewall.ActionDrop {
|
if action == firewall.ActionDrop {
|
||||||
// after the established rule
|
// after the established rule
|
||||||
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
err = r.iptablesClient.Insert(tableFilter, chainRTFWDIN, 2, rule...)
|
||||||
@@ -176,17 +184,13 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
ruleKey := rule.ID()
|
ruleKey := rule.ID()
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
setName := r.findSetNameInRule(rule)
|
|
||||||
|
|
||||||
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
if err := r.iptablesClient.Delete(tableFilter, chainRTFWDIN, rule...); err != nil {
|
||||||
return fmt.Errorf("delete route rule: %v", err)
|
return fmt.Errorf("delete route rule: %v", err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
if setName != "" {
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
return fmt.Errorf("failed to remove ipset: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("route rule %s not found", ruleKey)
|
log.Debugf("route rule %s not found", ruleKey)
|
||||||
@@ -197,13 +201,26 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) findSetNameInRule(rule []string) string {
|
func (r *router) decrementSetCounter(rule []string) error {
|
||||||
for i, arg := range rule {
|
sets := r.findSets(rule)
|
||||||
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
var merr *multierror.Error
|
||||||
return rule[i+3]
|
for _, setName := range sets {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement counter: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSets(rule []string) []string {
|
||||||
|
var sets []string
|
||||||
|
for i, arg := range rule {
|
||||||
|
if arg == "-m" && i+3 < len(rule) && rule[i+1] == "set" && rule[i+2] == matchSet {
|
||||||
|
sets = append(sets, rule[i+3])
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
func (r *router) createIpSet(setName string, sources []netip.Prefix) error {
|
||||||
@@ -224,6 +241,8 @@ func (r *router) deleteIpSet(setName string) error {
|
|||||||
if err := ipset.Destroy(setName); err != nil {
|
if err := ipset.Destroy(setName); err != nil {
|
||||||
return fmt.Errorf("destroy set %s: %w", setName, err)
|
return fmt.Errorf("destroy set %s: %w", setName, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
log.Debugf("Deleted unused ipset %s", setName)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -263,12 +282,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
log.Errorf("%v", err)
|
log.Errorf("%v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if pair.Masquerade {
|
||||||
return fmt.Errorf("remove nat rule: %w", err)
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
}
|
return fmt.Errorf("remove nat rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse nat rule: %w", err)
|
return fmt.Errorf("remove inverse nat rule: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
@@ -306,8 +327,10 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
|
||||||
log.Debugf("legacy forwarding rule %s not found", ruleKey)
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -347,12 +370,16 @@ func (r *router) Reset() error {
|
|||||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
r.rules = make(map[string][]string)
|
|
||||||
|
|
||||||
if err := r.ipsetCounter.Flush(); err != nil {
|
if err := r.ipsetCounter.Flush(); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules = make(map[string][]string)
|
||||||
r.updateState()
|
r.updateState()
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -422,6 +449,57 @@ func (r *router) createContainers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupDataPlaneMark configures the fwmark for the data plane
|
||||||
|
func (r *router) setupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
preRule := []string{
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePre] = preRule
|
||||||
|
}
|
||||||
|
|
||||||
|
postRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePost] = postRule
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) cleanupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
if preRule, exists := r.rules[markManglePre]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePre)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if postRule, exists := r.rules[markManglePost]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) addPostroutingRules() error {
|
func (r *router) addPostroutingRules() error {
|
||||||
// First rule for outbound masquerade
|
// First rule for outbound masquerade
|
||||||
rule1 := []string{
|
rule1 := []string{
|
||||||
@@ -463,7 +541,7 @@ func (r *router) insertEstablishedRule(chain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) addJumpRules() error {
|
func (r *router) addJumpRules() error {
|
||||||
// Jump to NAT chain
|
// Jump to nat chain
|
||||||
natRule := []string{"-j", chainRTNAT}
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||||
@@ -537,12 +615,26 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
rule = append(rule,
|
rule = append(rule,
|
||||||
"-m", "conntrack",
|
"-m", "conntrack",
|
||||||
"--ctstate", "NEW",
|
"--ctstate", "NEW",
|
||||||
"-s", pair.Source.String(),
|
)
|
||||||
"-d", pair.Destination.String(),
|
sourceExp, err := r.applyNetwork("-s", pair.Source, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -s: %w", err)
|
||||||
|
}
|
||||||
|
destExp, err := r.applyNetwork("-d", pair.Destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply network -d: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
rule = append(rule, sourceExp...)
|
||||||
|
rule = append(rule, destExp...)
|
||||||
|
rule = append(rule,
|
||||||
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
"-j", "MARK", "--set-mark", fmt.Sprintf("%#x", markValue),
|
||||||
)
|
)
|
||||||
|
|
||||||
if err := r.iptablesClient.Append(tableMangle, chainRTPRE, rule...); err != nil {
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
if err := r.iptablesClient.Insert(tableMangle, chainRTPRE, 1, rule...); err != nil {
|
||||||
|
// TODO: rollback ipset counter
|
||||||
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while adding marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -560,6 +652,10 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
return fmt.Errorf("error while removing marking rule for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement ipset counter: %w", err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("marking rule %s not found", ruleKey)
|
log.Debugf("marking rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
@@ -725,17 +821,21 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
func (r *router) genRouteRuleSpec(params routeFilteringRuleParams, sources []netip.Prefix) ([]string, error) {
|
||||||
var rule []string
|
var rule []string
|
||||||
|
|
||||||
if params.SetName != "" {
|
sourceExp, err := r.applyNetwork("-s", params.Source, sources)
|
||||||
rule = append(rule, "-m", "set", matchSet, params.SetName, "src")
|
if err != nil {
|
||||||
} else if len(params.Sources) > 0 {
|
return nil, fmt.Errorf("apply network -s: %w", err)
|
||||||
source := params.Sources[0]
|
|
||||||
rule = append(rule, "-s", source.String())
|
}
|
||||||
|
destExp, err := r.applyNetwork("-d", params.Destination, nil)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply network -d: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
rule = append(rule, "-d", params.Destination.String())
|
rule = append(rule, sourceExp...)
|
||||||
|
rule = append(rule, destExp...)
|
||||||
|
|
||||||
if params.Proto != firewall.ProtocolALL {
|
if params.Proto != firewall.ProtocolALL {
|
||||||
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
rule = append(rule, "-p", strings.ToLower(string(params.Proto)))
|
||||||
@@ -745,7 +845,47 @@ func genRouteFilteringRuleSpec(params routeFilteringRuleParams) []string {
|
|||||||
|
|
||||||
rule = append(rule, "-j", actionToStr(params.Action))
|
rule = append(rule, "-j", actionToStr(params.Action))
|
||||||
|
|
||||||
return rule
|
return rule, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) applyNetwork(flag string, network firewall.Network, prefixes []netip.Prefix) ([]string, error) {
|
||||||
|
direction := "src"
|
||||||
|
if flag == "-d" {
|
||||||
|
direction = "dst"
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsSet() {
|
||||||
|
if _, err := r.ipsetCounter.Increment(network.Set.HashedName(), prefixes); err != nil {
|
||||||
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return []string{"-m", "set", matchSet, network.Set.HashedName(), direction}, nil
|
||||||
|
}
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return []string{flag, network.Prefix.String()}, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// nolint:nilnil
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
for _, prefix := range prefixes {
|
||||||
|
// TODO: Implement IPv6 support
|
||||||
|
if prefix.Addr().Is6() {
|
||||||
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ipset.AddPrefix(set.HashedName(), prefix); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("increment ipset counter: %w", err))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if merr == nil {
|
||||||
|
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func applyPort(flag string, port *firewall.Port) []string {
|
func applyPort(flag string, port *firewall.Port) []string {
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
// 5. jump rule to PRE nat chain
|
// 5. jump rule to PRE nat chain
|
||||||
// 6. static outbound masquerade rule
|
// 6. static outbound masquerade rule
|
||||||
// 7. static return masquerade rule
|
// 7. static return masquerade rule
|
||||||
require.Len(t, manager.rules, 7, "should have created rules map")
|
// 8. mangle prerouting mark rule
|
||||||
|
// 9. mangle postrouting mark rule
|
||||||
|
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
@@ -58,8 +60,8 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
|
|
||||||
pair := firewall.RouterPair{
|
pair := firewall.RouterPair{
|
||||||
ID: "abc",
|
ID: "abc",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.100.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -330,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
@@ -345,23 +347,29 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
assert.NoError(t, err, "Failed to check rule existence")
|
assert.NoError(t, err, "Failed to check rule existence")
|
||||||
assert.True(t, exists, "Rule not found in iptables")
|
assert.True(t, exists, "Rule not found in iptables")
|
||||||
|
|
||||||
|
var source firewall.Network
|
||||||
|
if len(tt.sources) > 1 {
|
||||||
|
source.Set = firewall.NewPrefixSet(tt.sources)
|
||||||
|
} else if len(tt.sources) > 0 {
|
||||||
|
source.Prefix = tt.sources[0]
|
||||||
|
}
|
||||||
// Verify rule content
|
// Verify rule content
|
||||||
params := routeFilteringRuleParams{
|
params := routeFilteringRuleParams{
|
||||||
Sources: tt.sources,
|
Source: source,
|
||||||
Destination: tt.destination,
|
Destination: firewall.Network{Prefix: tt.destination},
|
||||||
Proto: tt.proto,
|
Proto: tt.proto,
|
||||||
SPort: tt.sPort,
|
SPort: tt.sPort,
|
||||||
DPort: tt.dPort,
|
DPort: tt.dPort,
|
||||||
Action: tt.action,
|
Action: tt.action,
|
||||||
SetName: "",
|
|
||||||
}
|
}
|
||||||
|
|
||||||
expectedRule := genRouteFilteringRuleSpec(params)
|
expectedRule, err := r.genRouteRuleSpec(params, nil)
|
||||||
|
require.NoError(t, err, "Failed to generate expected rule spec")
|
||||||
|
|
||||||
if tt.expectSet {
|
if tt.expectSet {
|
||||||
setName := firewall.GenerateSetName(tt.sources)
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
params.SetName = setName
|
expectedRule, err = r.genRouteRuleSpec(params, nil)
|
||||||
expectedRule = genRouteFilteringRuleSpec(params)
|
require.NoError(t, err, "Failed to generate expected rule spec with set")
|
||||||
|
|
||||||
// Check if the set was created
|
// Check if the set was created
|
||||||
_, exists := r.ipsetCounter.Get(setName)
|
_, exists := r.ipsetCounter.Get(setName)
|
||||||
@@ -376,3 +384,62 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestFindSetNameInRule(t *testing.T) {
|
||||||
|
r := &router{}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
rule []string
|
||||||
|
expected []string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Basic rule with two sets",
|
||||||
|
rule: []string{
|
||||||
|
"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-m", "set", "--match-set", "nb-2e5a2a05", "src",
|
||||||
|
"-m", "set", "--match-set", "nb-349ae051", "dst", "-m", "tcp", "--dport", "8080", "-j", "ACCEPT",
|
||||||
|
},
|
||||||
|
expected: []string{"nb-2e5a2a05", "nb-349ae051"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "No sets",
|
||||||
|
rule: []string{"-A", "NETBIRD-RT-FWD-IN", "-p", "tcp", "-j", "ACCEPT"},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Multiple sets with different positions",
|
||||||
|
rule: []string{
|
||||||
|
"-m", "set", "--match-set", "set1", "src", "-p", "tcp",
|
||||||
|
"-m", "set", "--match-set", "set-abc123", "dst", "-j", "ACCEPT",
|
||||||
|
},
|
||||||
|
expected: []string{"set1", "set-abc123"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Boundary case - sequence appears at end",
|
||||||
|
rule: []string{"-p", "tcp", "-m", "set", "--match-set", "final-set"},
|
||||||
|
expected: []string{"final-set"},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Incomplete pattern - missing set name",
|
||||||
|
rule: []string{"-p", "tcp", "-m", "set", "--match-set"},
|
||||||
|
expected: []string{},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
result := r.findSets(tc.rule)
|
||||||
|
|
||||||
|
if len(result) != len(tc.expected) {
|
||||||
|
t.Errorf("Expected %d sets, got %d. Sets found: %v", len(tc.expected), len(result), result)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
for i, set := range result {
|
||||||
|
if set != tc.expected[i] {
|
||||||
|
t.Errorf("Expected set %q at position %d, got %q", tc.expected[i], i, set)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -1,13 +1,10 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"crypto/sha256"
|
|
||||||
"encoding/hex"
|
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"sort"
|
"sort"
|
||||||
"strings"
|
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
@@ -43,6 +40,18 @@ const (
|
|||||||
// Action is the action to be taken on a rule
|
// Action is the action to be taken on a rule
|
||||||
type Action int
|
type Action int
|
||||||
|
|
||||||
|
// String returns the string representation of the action
|
||||||
|
func (a Action) String() string {
|
||||||
|
switch a {
|
||||||
|
case ActionAccept:
|
||||||
|
return "accept"
|
||||||
|
case ActionDrop:
|
||||||
|
return "drop"
|
||||||
|
default:
|
||||||
|
return "unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
// ActionAccept is the action to accept a packet
|
// ActionAccept is the action to accept a packet
|
||||||
ActionAccept Action = iota
|
ActionAccept Action = iota
|
||||||
@@ -50,6 +59,33 @@ const (
|
|||||||
ActionDrop
|
ActionDrop
|
||||||
)
|
)
|
||||||
|
|
||||||
|
// Network is a rule destination, either a set or a prefix
|
||||||
|
type Network struct {
|
||||||
|
Set Set
|
||||||
|
Prefix netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the destination
|
||||||
|
func (d Network) String() string {
|
||||||
|
if d.Prefix.IsValid() {
|
||||||
|
return d.Prefix.String()
|
||||||
|
}
|
||||||
|
if d.IsSet() {
|
||||||
|
return d.Set.HashedName()
|
||||||
|
}
|
||||||
|
return "<invalid network>"
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsSet returns true if the destination is a set
|
||||||
|
func (d Network) IsSet() bool {
|
||||||
|
return d.Set != Set{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsPrefix returns true if the destination is a valid prefix
|
||||||
|
func (d Network) IsPrefix() bool {
|
||||||
|
return d.Prefix.IsValid()
|
||||||
|
}
|
||||||
|
|
||||||
// Manager is the high level abstraction of a firewall manager
|
// Manager is the high level abstraction of a firewall manager
|
||||||
//
|
//
|
||||||
// It declares methods which handle actions required by the
|
// It declares methods which handle actions required by the
|
||||||
@@ -65,13 +101,13 @@ type Manager interface {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
AddPeerFiltering(
|
AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]Rule, error)
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -80,7 +116,14 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination Network,
|
||||||
|
proto Protocol,
|
||||||
|
sPort, dPort *Port,
|
||||||
|
action Action,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule
|
// DeleteRouteRule deletes a routing rule
|
||||||
DeleteRouteRule(rule Rule) error
|
DeleteRouteRule(rule Rule) error
|
||||||
@@ -111,6 +154,9 @@ type Manager interface {
|
|||||||
|
|
||||||
// DeleteDNATRule deletes a DNAT rule
|
// DeleteDNATRule deletes a DNAT rule
|
||||||
DeleteDNATRule(Rule) error
|
DeleteDNATRule(Rule) error
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
UpdateSet(hash Set, prefixes []netip.Prefix) error
|
||||||
}
|
}
|
||||||
|
|
||||||
func GenKey(format string, pair RouterPair) string {
|
func GenKey(format string, pair RouterPair) string {
|
||||||
@@ -145,22 +191,6 @@ func SetLegacyManagement(router LegacyManager, isLegacy bool) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// GenerateSetName generates a unique name for an ipset based on the given sources.
|
|
||||||
func GenerateSetName(sources []netip.Prefix) string {
|
|
||||||
// sort for consistent naming
|
|
||||||
SortPrefixes(sources)
|
|
||||||
|
|
||||||
var sourcesStr strings.Builder
|
|
||||||
for _, src := range sources {
|
|
||||||
sourcesStr.WriteString(src.String())
|
|
||||||
}
|
|
||||||
|
|
||||||
hash := sha256.Sum256([]byte(sourcesStr.String()))
|
|
||||||
shortHash := hex.EncodeToString(hash[:])[:8]
|
|
||||||
|
|
||||||
return fmt.Sprintf("nb-%s", shortHash)
|
|
||||||
}
|
|
||||||
|
|
||||||
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
// MergeIPRanges merges overlapping IP ranges and returns a slice of non-overlapping netip.Prefix
|
||||||
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
func MergeIPRanges(prefixes []netip.Prefix) []netip.Prefix {
|
||||||
if len(prefixes) == 0 {
|
if len(prefixes) == 0 {
|
||||||
|
|||||||
@@ -20,8 +20,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result1 := manager.GenerateSetName(prefixes1)
|
result1 := manager.NewPrefixSet(prefixes1)
|
||||||
result2 := manager.GenerateSetName(prefixes2)
|
result2 := manager.NewPrefixSet(prefixes2)
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
t.Errorf("Different orders produced different hashes: %s != %s", result1, result2)
|
||||||
@@ -34,9 +34,9 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("10.0.0.0/8"),
|
netip.MustParsePrefix("10.0.0.0/8"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result := manager.GenerateSetName(prefixes)
|
result := manager.NewPrefixSet(prefixes)
|
||||||
|
|
||||||
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result)
|
matched, err := regexp.MatchString(`^nb-[0-9a-f]{8}$`, result.HashedName())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Error matching regex: %v", err)
|
t.Fatalf("Error matching regex: %v", err)
|
||||||
}
|
}
|
||||||
@@ -46,8 +46,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
t.Run("Empty input produces consistent result", func(t *testing.T) {
|
||||||
result1 := manager.GenerateSetName([]netip.Prefix{})
|
result1 := manager.NewPrefixSet([]netip.Prefix{})
|
||||||
result2 := manager.GenerateSetName([]netip.Prefix{})
|
result2 := manager.NewPrefixSet([]netip.Prefix{})
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
t.Errorf("Empty input produced inconsistent results: %s != %s", result1, result2)
|
||||||
@@ -64,8 +64,8 @@ func TestGenerateSetName(t *testing.T) {
|
|||||||
netip.MustParsePrefix("192.168.1.0/24"),
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
}
|
}
|
||||||
|
|
||||||
result1 := manager.GenerateSetName(prefixes1)
|
result1 := manager.NewPrefixSet(prefixes1)
|
||||||
result2 := manager.GenerateSetName(prefixes2)
|
result2 := manager.NewPrefixSet(prefixes2)
|
||||||
|
|
||||||
if result1 != result2 {
|
if result1 != result2 {
|
||||||
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
t.Errorf("Different orders of IPv4 and IPv6 produced different hashes: %s != %s", result1, result2)
|
||||||
|
|||||||
@@ -1,15 +1,13 @@
|
|||||||
package manager
|
package manager
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net/netip"
|
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/route"
|
"github.com/netbirdio/netbird/route"
|
||||||
)
|
)
|
||||||
|
|
||||||
type RouterPair struct {
|
type RouterPair struct {
|
||||||
ID route.ID
|
ID route.ID
|
||||||
Source netip.Prefix
|
Source Network
|
||||||
Destination netip.Prefix
|
Destination Network
|
||||||
Masquerade bool
|
Masquerade bool
|
||||||
Inverse bool
|
Inverse bool
|
||||||
}
|
}
|
||||||
|
|||||||
74
client/firewall/manager/set.go
Normal file
74
client/firewall/manager/set.go
Normal file
@@ -0,0 +1,74 @@
|
|||||||
|
package manager
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/sha256"
|
||||||
|
"encoding/hex"
|
||||||
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
|
"slices"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
|
)
|
||||||
|
|
||||||
|
type Set struct {
|
||||||
|
hash [4]byte
|
||||||
|
comment string
|
||||||
|
}
|
||||||
|
|
||||||
|
// String returns the string representation of the set: hashed name and comment
|
||||||
|
func (h Set) String() string {
|
||||||
|
if h.comment == "" {
|
||||||
|
return h.HashedName()
|
||||||
|
}
|
||||||
|
return h.HashedName() + ": " + h.comment
|
||||||
|
}
|
||||||
|
|
||||||
|
// HashedName returns the string representation of the hash
|
||||||
|
func (h Set) HashedName() string {
|
||||||
|
return fmt.Sprintf(
|
||||||
|
"nb-%s",
|
||||||
|
hex.EncodeToString(h.hash[:]),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Comment returns the comment of the set
|
||||||
|
func (h Set) Comment() string {
|
||||||
|
return h.comment
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewPrefixSet generates a unique name for an ipset based on the given prefixes.
|
||||||
|
func NewPrefixSet(prefixes []netip.Prefix) Set {
|
||||||
|
// sort for consistent naming
|
||||||
|
SortPrefixes(prefixes)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
for _, src := range prefixes {
|
||||||
|
bytes, err := src.MarshalBinary()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("failed to marshal prefix %s: %v", src, err)
|
||||||
|
}
|
||||||
|
hash.Write(bytes)
|
||||||
|
}
|
||||||
|
var set Set
|
||||||
|
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||||
|
|
||||||
|
return set
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewDomainSet generates a unique name for an ipset based on the given domains.
|
||||||
|
func NewDomainSet(domains domain.List) Set {
|
||||||
|
slices.Sort(domains)
|
||||||
|
|
||||||
|
hash := sha256.New()
|
||||||
|
for _, d := range domains {
|
||||||
|
hash.Write([]byte(d.PunycodeString()))
|
||||||
|
}
|
||||||
|
set := Set{
|
||||||
|
comment: domains.SafeString(),
|
||||||
|
}
|
||||||
|
copy(set.hash[:], hash.Sum(nil)[:4])
|
||||||
|
|
||||||
|
return set
|
||||||
|
}
|
||||||
@@ -25,9 +25,10 @@ const (
|
|||||||
chainNameInputRules = "netbird-acl-input-rules"
|
chainNameInputRules = "netbird-acl-input-rules"
|
||||||
|
|
||||||
// filter chains contains the rules that jump to the rules chains
|
// filter chains contains the rules that jump to the rules chains
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
chainNamePrerouting = "netbird-rt-prerouting"
|
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||||
|
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||||
|
|
||||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||||
)
|
)
|
||||||
@@ -84,13 +85,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *AclManager) AddPeerFiltering(
|
func (m *AclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var ipset *nftables.Set
|
var ipset *nftables.Set
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
@@ -102,7 +103,7 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -256,7 +257,6 @@ func (m *AclManager) addIOFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipset *nftables.Set,
|
ipset *nftables.Set,
|
||||||
comment string,
|
|
||||||
) (*Rule, error) {
|
) (*Rule, error) {
|
||||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
@@ -338,7 +338,7 @@ func (m *AclManager) addIOFiltering(
|
|||||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
userData := []byte(ruleId)
|
||||||
|
|
||||||
chain := m.chainInputRules
|
chain := m.chainInputRules
|
||||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
@@ -463,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
// netbird peer IP.
|
// netbird peer IP.
|
||||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
// Chain is created by route manager
|
||||||
Name: chainNamePrerouting,
|
// TODO: move creation to a common place
|
||||||
|
m.chainPrerouting = &nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
})
|
}
|
||||||
|
|
||||||
m.addFwmarkToForward(chainFwFilter)
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
|
|||||||
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -129,25 +129,25 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort, dPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if !destination.Addr().Is4() {
|
if destination.IsPrefix() && !destination.Prefix.Addr().Is4() {
|
||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Prefix.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -241,7 +241,7 @@ func (m *Manager) SetLegacyManagement(isLegacy bool) error {
|
|||||||
return firewall.SetLegacyManagement(m.router, isLegacy)
|
return firewall.SetLegacyManagement(m.router, isLegacy)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close closes the firewall manager
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -358,6 +358,14 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return m.router.DeleteDNATRule(rule)
|
return m.router.DeleteDNATRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateSet updates the set with the given prefixes
|
||||||
|
func (m *Manager) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
|
m.mutex.Lock()
|
||||||
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
|
return m.router.UpdateSet(set, prefixes)
|
||||||
|
}
|
||||||
|
|
||||||
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
func (m *Manager) createWorkTable() (*nftables.Table, error) {
|
||||||
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := m.rConn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -283,12 +283,13 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := net.ParseIP("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
fw.Network{Prefix: netip.MustParsePrefix("10.1.0.0/24")},
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
nil,
|
nil,
|
||||||
&fw.Port{Values: []uint16{443}},
|
&fw.Port{Values: []uint16{443}},
|
||||||
@@ -297,8 +298,8 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
require.NoError(t, err, "failed to add route filtering rule")
|
require.NoError(t, err, "failed to add route filtering rule")
|
||||||
|
|
||||||
pair := fw.RouterPair{
|
pair := fw.RouterPair{
|
||||||
Source: netip.MustParsePrefix("192.168.1.0/24"),
|
Source: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
Destination: netip.MustParsePrefix("10.0.0.0/24"),
|
Destination: fw.Network{Prefix: netip.MustParsePrefix("10.0.0.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
}
|
}
|
||||||
err = manager.AddNatRule(pair)
|
err = manager.AddNatRule(pair)
|
||||||
|
|||||||
@@ -10,7 +10,6 @@ import (
|
|||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
"github.com/coreos/go-iptables/iptables"
|
"github.com/coreos/go-iptables/iptables"
|
||||||
"github.com/davecgh/go-spew/spew"
|
|
||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/binaryutil"
|
"github.com/google/nftables/binaryutil"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
@@ -20,7 +19,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@@ -44,9 +43,14 @@ const (
|
|||||||
const refreshRulesMapError = "refresh rules map: %w"
|
const refreshRulesMapError = "refresh rules map: %w"
|
||||||
|
|
||||||
var (
|
var (
|
||||||
errFilterTableNotFound = fmt.Errorf("nftables: 'filter' table not found")
|
errFilterTableNotFound = fmt.Errorf("'filter' table not found")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
type setInput struct {
|
||||||
|
set firewall.Set
|
||||||
|
prefixes []netip.Prefix
|
||||||
|
}
|
||||||
|
|
||||||
type router struct {
|
type router struct {
|
||||||
conn *nftables.Conn
|
conn *nftables.Conn
|
||||||
workTable *nftables.Table
|
workTable *nftables.Table
|
||||||
@@ -54,7 +58,7 @@ type router struct {
|
|||||||
chains map[string]*nftables.Chain
|
chains map[string]*nftables.Chain
|
||||||
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
// rules is useful to avoid duplicates and to get missing attributes that we don't have when adding new rules
|
||||||
rules map[string]*nftables.Rule
|
rules map[string]*nftables.Rule
|
||||||
ipsetCounter *refcounter.Counter[string, []netip.Prefix, *nftables.Set]
|
ipsetCounter *refcounter.Counter[string, setInput, *nftables.Set]
|
||||||
|
|
||||||
wgIface iFaceMapper
|
wgIface iFaceMapper
|
||||||
ipFwdState *ipfwdstate.IPForwardingState
|
ipFwdState *ipfwdstate.IPForwardingState
|
||||||
@@ -100,6 +104,10 @@ func (r *router) init(workTable *nftables.Table) error {
|
|||||||
return fmt.Errorf("create containers: %w", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -159,7 +167,7 @@ func (r *router) removeNatPreroutingRules() error {
|
|||||||
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
func (r *router) loadFilterTable() (*nftables.Table, error) {
|
||||||
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
tables, err := r.conn.ListTablesOfFamily(nftables.TableFamilyIPv4)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("nftables: unable to list tables: %v", err)
|
return nil, fmt.Errorf("unable to list tables: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
@@ -196,15 +204,21 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Chain is created by acl manager
|
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
// TODO: move creation to a common place
|
Name: chainNameManglePostrouting,
|
||||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
Table: r.workTable,
|
||||||
Name: chainNamePrerouting,
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
}
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
// Add the single NAT rule that matches on mark
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
@@ -220,7 +234,83 @@ func (r *router) createContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDataPlaneMark configures the fwmark for the data plane
|
||||||
|
func (r *router) setupDataPlaneMark() error {
|
||||||
|
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||||
|
return errors.New("no mangle chains found")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctNew := getCtNewExprs()
|
||||||
|
preExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
preExprs = append(preExprs, ctNew...)
|
||||||
|
preExprs = append(preExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
preNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
|
Exprs: preExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(preNftRule)
|
||||||
|
|
||||||
|
postExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postExprs = append(postExprs, ctNew...)
|
||||||
|
postExprs = append(postExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
postNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePostrouting],
|
||||||
|
Exprs: postExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(postNftRule)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -228,15 +318,16 @@ func (r *router) createContainers() error {
|
|||||||
|
|
||||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination firewall.Network,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
@@ -244,23 +335,29 @@ func (r *router) AddRouteFiltering(
|
|||||||
chain := r.chains[chainNameRoutingFw]
|
chain := r.chains[chainNameRoutingFw]
|
||||||
var exprs []expr.Any
|
var exprs []expr.Any
|
||||||
|
|
||||||
|
var source firewall.Network
|
||||||
switch {
|
switch {
|
||||||
case len(sources) == 1 && sources[0].Bits() == 0:
|
case len(sources) == 1 && sources[0].Bits() == 0:
|
||||||
// If it's 0.0.0.0/0, we don't need to add any source matching
|
// If it's 0.0.0.0/0, we don't need to add any source matching
|
||||||
case len(sources) == 1:
|
case len(sources) == 1:
|
||||||
// If there's only one source, we can use it directly
|
// If there's only one source, we can use it directly
|
||||||
exprs = append(exprs, generateCIDRMatcherExpressions(true, sources[0])...)
|
source.Prefix = sources[0]
|
||||||
default:
|
default:
|
||||||
// If there are multiple sources, create or get an ipset
|
// If there are multiple sources, use a set
|
||||||
var err error
|
source.Set = firewall.NewPrefixSet(sources)
|
||||||
exprs, err = r.getIpSetExprs(sources, exprs)
|
|
||||||
if err != nil {
|
|
||||||
return nil, fmt.Errorf("get ipset expressions: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle destination
|
sourceExp, err := r.applyNetwork(source, sources, true)
|
||||||
exprs = append(exprs, generateCIDRMatcherExpressions(false, destination)...)
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, sourceExp...)
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
// Handle protocol
|
// Handle protocol
|
||||||
if proto != firewall.ProtocolALL {
|
if proto != firewall.ProtocolALL {
|
||||||
@@ -304,39 +401,27 @@ func (r *router) AddRouteFiltering(
|
|||||||
rule = r.conn.AddRule(rule)
|
rule = r.conn.AddRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Tracef("Adding route rule %s", spew.Sdump(rule))
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return nil, fmt.Errorf(flushError, err)
|
return nil, fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[string(ruleKey)] = rule
|
r.rules[string(ruleKey)] = rule
|
||||||
|
|
||||||
log.Debugf("nftables: added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
log.Debugf("added route rule: sources=%v, destination=%v, proto=%v, sPort=%v, dPort=%v, action=%v", sources, destination, proto, sPort, dPort, action)
|
||||||
|
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) getIpSetExprs(sources []netip.Prefix, exprs []expr.Any) ([]expr.Any, error) {
|
func (r *router) getIpSet(set firewall.Set, prefixes []netip.Prefix, isSource bool) ([]expr.Any, error) {
|
||||||
setName := firewall.GenerateSetName(sources)
|
ref, err := r.ipsetCounter.Increment(set.HashedName(), setInput{
|
||||||
ref, err := r.ipsetCounter.Increment(setName, sources)
|
set: set,
|
||||||
|
prefixes: prefixes,
|
||||||
|
})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("create or get ipset for sources: %w", err)
|
return nil, fmt.Errorf("create or get ipset: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs = append(exprs,
|
return getIpSetExprs(ref, isSource)
|
||||||
&expr.Payload{
|
|
||||||
DestRegister: 1,
|
|
||||||
Base: expr.PayloadBaseNetworkHeader,
|
|
||||||
Offset: 12,
|
|
||||||
Len: 4,
|
|
||||||
},
|
|
||||||
&expr.Lookup{
|
|
||||||
SourceRegister: 1,
|
|
||||||
SetName: ref.Out.Name,
|
|
||||||
SetID: ref.Out.ID,
|
|
||||||
},
|
|
||||||
)
|
|
||||||
return exprs, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
@@ -355,42 +440,54 @@ func (r *router) DeleteRouteRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
return fmt.Errorf("route rule %s has no handle", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
setName := r.findSetNameInRule(nftRule)
|
|
||||||
|
|
||||||
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
if err := r.deleteNftRule(nftRule, ruleKey); err != nil {
|
||||||
return fmt.Errorf("delete: %w", err)
|
return fmt.Errorf("delete: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if setName != "" {
|
|
||||||
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
|
||||||
return fmt.Errorf("decrement ipset reference: %w", err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(nftRule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.Set, error) {
|
func (r *router) createIpSet(setName string, input setInput) (*nftables.Set, error) {
|
||||||
// overlapping prefixes will result in an error, so we need to merge them
|
// overlapping prefixes will result in an error, so we need to merge them
|
||||||
sources = firewall.MergeIPRanges(sources)
|
prefixes := firewall.MergeIPRanges(input.prefixes)
|
||||||
|
|
||||||
set := &nftables.Set{
|
nfset := &nftables.Set{
|
||||||
Name: setName,
|
Name: setName,
|
||||||
Table: r.workTable,
|
Comment: input.set.Comment(),
|
||||||
|
Table: r.workTable,
|
||||||
// required for prefixes
|
// required for prefixes
|
||||||
Interval: true,
|
Interval: true,
|
||||||
KeyType: nftables.TypeIPAddr,
|
KeyType: nftables.TypeIPAddr,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
elements := convertPrefixesToSet(prefixes)
|
||||||
|
if err := r.conn.AddSet(nfset, elements); err != nil {
|
||||||
|
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return nil, fmt.Errorf("flush error: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
||||||
|
|
||||||
|
return nfset, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func convertPrefixesToSet(prefixes []netip.Prefix) []nftables.SetElement {
|
||||||
var elements []nftables.SetElement
|
var elements []nftables.SetElement
|
||||||
for _, prefix := range sources {
|
for _, prefix := range prefixes {
|
||||||
// TODO: Implement IPv6 support
|
// TODO: Implement IPv6 support
|
||||||
if prefix.Addr().Is6() {
|
if prefix.Addr().Is6() {
|
||||||
log.Printf("Skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
log.Tracef("skipping IPv6 prefix %s: IPv6 support not yet implemented", prefix)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -406,18 +503,7 @@ func (r *router) createIpSet(setName string, sources []netip.Prefix) (*nftables.
|
|||||||
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
nftables.SetElement{Key: lastIP.AsSlice(), IntervalEnd: true},
|
||||||
)
|
)
|
||||||
}
|
}
|
||||||
|
return elements
|
||||||
if err := r.conn.AddSet(set, elements); err != nil {
|
|
||||||
return nil, fmt.Errorf("error adding elements to set %s: %w", setName, err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
|
||||||
return nil, fmt.Errorf("flush error: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
log.Printf("Created new ipset: %s with %d elements", setName, len(elements)/2)
|
|
||||||
|
|
||||||
return set, nil
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// calculateLastIP determines the last IP in a given prefix.
|
// calculateLastIP determines the last IP in a given prefix.
|
||||||
@@ -441,8 +527,8 @@ func uint32ToBytes(ip uint32) [4]byte {
|
|||||||
return b
|
return b
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
func (r *router) deleteIpSet(setName string, nfset *nftables.Set) error {
|
||||||
r.conn.DelSet(set)
|
r.conn.DelSet(nfset)
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf(flushError, err)
|
return fmt.Errorf(flushError, err)
|
||||||
}
|
}
|
||||||
@@ -451,13 +537,27 @@ func (r *router) deleteIpSet(setName string, set *nftables.Set) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) findSetNameInRule(rule *nftables.Rule) string {
|
func (r *router) decrementSetCounter(rule *nftables.Rule) error {
|
||||||
for _, e := range rule.Exprs {
|
sets := r.findSets(rule)
|
||||||
if lookup, ok := e.(*expr.Lookup); ok {
|
|
||||||
return lookup.SetName
|
var merr *multierror.Error
|
||||||
|
for _, setName := range sets {
|
||||||
|
if _, err := r.ipsetCounter.Decrement(setName); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("decrement set counter: %w", err))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return ""
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) findSets(rule *nftables.Rule) []string {
|
||||||
|
var sets []string
|
||||||
|
for _, e := range rule.Exprs {
|
||||||
|
if lookup, ok := e.(*expr.Lookup); ok {
|
||||||
|
sets = append(sets, lookup.SetName)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return sets
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
func (r *router) deleteNftRule(rule *nftables.Rule, ruleKey string) error {
|
||||||
@@ -499,7 +599,8 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: insert rules for %s: %v", pair.Destination, err)
|
// TODO: rollback ipset counter
|
||||||
|
return fmt.Errorf("insert rules for %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -507,8 +608,15 @@ func (r *router) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// addNatRule inserts a nftables rule to the conn client flush queue
|
// addNatRule inserts a nftables rule to the conn client flush queue
|
||||||
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
func (r *router) addNatRule(pair firewall.RouterPair) error {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
op := expr.CmpOpEq
|
op := expr.CmpOpEq
|
||||||
if pair.Inverse {
|
if pair.Inverse {
|
||||||
@@ -516,26 +624,6 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
|
||||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
|
||||||
&expr.Ct{
|
|
||||||
Key: expr.CtKeySTATE,
|
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
|
|
||||||
// interface matching
|
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: expr.MetaKeyIIFNAME,
|
Key: expr.MetaKeyIIFNAME,
|
||||||
Register: 1,
|
Register: 1,
|
||||||
@@ -546,6 +634,9 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
Data: ifname(r.wgIface.Name()),
|
Data: ifname(r.wgIface.Name()),
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||||
|
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||||
|
exprs = append(exprs, getCtNewExprs()...)
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
exprs = append(exprs, destExp...)
|
exprs = append(exprs, destExp...)
|
||||||
@@ -575,9 +666,11 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
// Ensure nat rules come first, so the mark can be overwritten.
|
||||||
|
// Currently overwritten by the dst-type LOCAL rules for redirected traffic.
|
||||||
|
r.rules[ruleKey] = r.conn.InsertRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNamePrerouting],
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
Exprs: exprs,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
@@ -658,8 +751,15 @@ func (r *router) addPostroutingRules() error {
|
|||||||
|
|
||||||
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
// addLegacyRouteRule adds a legacy routing rule for mgmt servers pre route acls
|
||||||
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, pair.Source)
|
sourceExp, err := r.applyNetwork(pair.Source, nil, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, pair.Destination)
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply source: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
destExp, err := r.applyNetwork(pair.Destination, nil, false)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("apply destination: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
exprs := []expr.Any{
|
exprs := []expr.Any{
|
||||||
&expr.Counter{},
|
&expr.Counter{},
|
||||||
@@ -668,7 +768,8 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
expression := append(sourceExp, append(destExp, exprs...)...) // nolint:gocritic
|
exprs = append(exprs, sourceExp...)
|
||||||
|
exprs = append(exprs, destExp...)
|
||||||
|
|
||||||
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.ForwardingFormat, pair)
|
||||||
|
|
||||||
@@ -681,7 +782,7 @@ func (r *router) addLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNameRoutingFw],
|
Chain: r.chains[chainNameRoutingFw],
|
||||||
Exprs: expression,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
return nil
|
return nil
|
||||||
@@ -696,11 +797,13 @@ func (r *router) removeLegacyRouteRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove legacy forwarding rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("removed legacy forwarding rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
} else {
|
|
||||||
log.Debugf("nftables: legacy forwarding rule %s not found", ruleKey)
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -911,12 +1014,14 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
return fmt.Errorf(refreshRulesMapError, err)
|
return fmt.Errorf(refreshRulesMapError, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(pair); err != nil {
|
if pair.Masquerade {
|
||||||
return fmt.Errorf("remove prerouting rule: %w", err)
|
if err := r.removeNatRule(pair); err != nil {
|
||||||
}
|
return fmt.Errorf("remove prerouting rule: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
if err := r.removeNatRule(firewall.GetInversePair(pair)); err != nil {
|
||||||
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
return fmt.Errorf("remove inverse prerouting rule: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.removeLegacyRouteRule(pair); err != nil {
|
if err := r.removeLegacyRouteRule(pair); err != nil {
|
||||||
@@ -924,10 +1029,10 @@ func (r *router) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: received error while applying rule removal for %s: %v", pair.Destination, err)
|
// TODO: rollback set counter
|
||||||
|
return fmt.Errorf("remove nat rules rule %s: %v", pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed nat rules for %s", pair.Destination)
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -935,16 +1040,19 @@ func (r *router) removeNatRule(pair firewall.RouterPair) error {
|
|||||||
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
ruleKey := firewall.GenKey(firewall.PreroutingFormat, pair)
|
||||||
|
|
||||||
if rule, exists := r.rules[ruleKey]; exists {
|
if rule, exists := r.rules[ruleKey]; exists {
|
||||||
err := r.conn.DelRule(rule)
|
if err := r.conn.DelRule(rule); err != nil {
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
return fmt.Errorf("remove prerouting rule %s -> %s: %v", pair.Source, pair.Destination, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Debugf("nftables: removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
log.Debugf("removed prerouting rule %s -> %s", pair.Source, pair.Destination)
|
||||||
|
|
||||||
delete(r.rules, ruleKey)
|
delete(r.rules, ruleKey)
|
||||||
|
|
||||||
|
if err := r.decrementSetCounter(rule); err != nil {
|
||||||
|
return fmt.Errorf("decrement set counter: %w", err)
|
||||||
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Debugf("nftables: prerouting rule %s not found", ruleKey)
|
log.Debugf("prerouting rule %s not found", ruleKey)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -956,7 +1064,7 @@ func (r *router) refreshRulesMap() error {
|
|||||||
for _, chain := range r.chains {
|
for _, chain := range r.chains {
|
||||||
rules, err := r.conn.GetRules(chain.Table, chain)
|
rules, err := r.conn.GetRules(chain.Table, chain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("nftables: unable to list rules: %v", err)
|
return fmt.Errorf(" unable to list rules: %v", err)
|
||||||
}
|
}
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 {
|
if len(rule.UserData) > 0 {
|
||||||
@@ -1230,13 +1338,54 @@ func (r *router) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
// generateCIDRMatcherExpressions generates nftables expressions that matches a CIDR
|
func (r *router) UpdateSet(set firewall.Set, prefixes []netip.Prefix) error {
|
||||||
func generateCIDRMatcherExpressions(source bool, prefix netip.Prefix) []expr.Any {
|
nfset, err := r.conn.GetSetByName(r.workTable, set.HashedName())
|
||||||
var offset uint32
|
if err != nil {
|
||||||
if source {
|
return fmt.Errorf("get set %s: %w", set.HashedName(), err)
|
||||||
offset = 12 // src offset
|
}
|
||||||
} else {
|
|
||||||
offset = 16 // dst offset
|
elements := convertPrefixesToSet(prefixes)
|
||||||
|
if err := r.conn.SetAddElements(nfset, elements); err != nil {
|
||||||
|
return fmt.Errorf("add elements to set %s: %w", set.HashedName(), err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf(flushError, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("updated set %s with prefixes %v", set.HashedName(), prefixes)
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyNetwork generates nftables expressions for networks (CIDR) or sets
|
||||||
|
func (r *router) applyNetwork(
|
||||||
|
network firewall.Network,
|
||||||
|
setPrefixes []netip.Prefix,
|
||||||
|
isSource bool,
|
||||||
|
) ([]expr.Any, error) {
|
||||||
|
if network.IsSet() {
|
||||||
|
exprs, err := r.getIpSet(network.Set, setPrefixes, isSource)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("source: %w", err)
|
||||||
|
}
|
||||||
|
return exprs, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
if network.IsPrefix() {
|
||||||
|
return applyPrefix(network.Prefix, isSource), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyPrefix generates nftables expressions for a CIDR prefix
|
||||||
|
func applyPrefix(prefix netip.Prefix, isSource bool) []expr.Any {
|
||||||
|
// dst offset
|
||||||
|
offset := uint32(16)
|
||||||
|
if isSource {
|
||||||
|
// src offset
|
||||||
|
offset = 12
|
||||||
}
|
}
|
||||||
|
|
||||||
ones := prefix.Bits()
|
ones := prefix.Bits()
|
||||||
@@ -1323,3 +1472,48 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
|
|
||||||
return exprs
|
return exprs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getCtNewExprs() []expr.Any {
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func getIpSetExprs(ref refcounter.Ref[*nftables.Set], isSource bool) ([]expr.Any, error) {
|
||||||
|
|
||||||
|
// dst offset
|
||||||
|
offset := uint32(16)
|
||||||
|
if isSource {
|
||||||
|
// src offset
|
||||||
|
offset = 12
|
||||||
|
}
|
||||||
|
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Payload{
|
||||||
|
DestRegister: 1,
|
||||||
|
Base: expr.PayloadBaseNetworkHeader,
|
||||||
|
Offset: offset,
|
||||||
|
Len: 4,
|
||||||
|
},
|
||||||
|
&expr.Lookup{
|
||||||
|
SourceRegister: 1,
|
||||||
|
SetName: ref.Out.Name,
|
||||||
|
SetID: ref.Out.ID,
|
||||||
|
},
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -88,8 +88,8 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Build CIDR matching expressions
|
// Build CIDR matching expressions
|
||||||
sourceExp := generateCIDRMatcherExpressions(true, testCase.InputPair.Source)
|
sourceExp := applyPrefix(testCase.InputPair.Source.Prefix, true)
|
||||||
destExp := generateCIDRMatcherExpressions(false, testCase.InputPair.Destination)
|
destExp := applyPrefix(testCase.InputPair.Destination.Prefix, false)
|
||||||
|
|
||||||
// Combine all expressions in the correct order
|
// Combine all expressions in the correct order
|
||||||
// nolint:gocritic
|
// nolint:gocritic
|
||||||
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range rtr.chains {
|
for _, chain := range rtr.chains {
|
||||||
if chain.Name == chainNamePrerouting {
|
if chain.Name == chainNameManglePrerouting {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
// Verify the rule was added
|
// Verify the rule was added
|
||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := false
|
found := false
|
||||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules")
|
require.NoError(t, err, "should list rules")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
|
|
||||||
// Verify the rule was removed
|
// Verify the rule was removed
|
||||||
found = false
|
found = false
|
||||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules after removal")
|
require.NoError(t, err, "should list rules after removal")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, firewall.Network{Prefix: tt.destination}, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
@@ -441,8 +441,8 @@ func TestNftablesCreateIpSet(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
setName := firewall.GenerateSetName(tt.sources)
|
setName := firewall.NewPrefixSet(tt.sources).HashedName()
|
||||||
set, err := r.createIpSet(setName, tt.sources)
|
set, err := r.createIpSet(setName, setInput{prefixes: tt.sources})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Logf("Failed to create IP set: %v", err)
|
t.Logf("Failed to create IP set: %v", err)
|
||||||
printNftSets()
|
printNftSets()
|
||||||
|
|||||||
@@ -15,8 +15,8 @@ var (
|
|||||||
Name: "Insert Forwarding IPV4 Rule",
|
Name: "Insert Forwarding IPV4 Rule",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: false,
|
Masquerade: false,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -24,8 +24,8 @@ var (
|
|||||||
Name: "Insert Forwarding And Nat IPV4 Rules",
|
Name: "Insert Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -40,8 +40,8 @@ var (
|
|||||||
Name: "Remove Forwarding And Nat IPV4 Rules",
|
Name: "Remove Forwarding And Nat IPV4 Rules",
|
||||||
InputPair: firewall.RouterPair{
|
InputPair: firewall.RouterPair{
|
||||||
ID: "zxa",
|
ID: "zxa",
|
||||||
Source: netip.MustParsePrefix("100.100.100.1/32"),
|
Source: firewall.Network{Prefix: netip.MustParsePrefix("100.100.100.1/32")},
|
||||||
Destination: netip.MustParsePrefix("100.100.200.0/24"),
|
Destination: firewall.Network{Prefix: netip.MustParsePrefix("100.100.200.0/24")},
|
||||||
Masquerade: true,
|
Masquerade: true,
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -11,13 +12,13 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Reset firewall to the default state
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@@ -31,8 +32,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
@@ -20,13 +21,13 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Close closes the firewall manager
|
// Close cleans up the firewall manager by removing all rules and closing trackers
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@@ -40,8 +41,8 @@ func (m *Manager) Close(*statemanager.Manager) error {
|
|||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -1,20 +1,27 @@
|
|||||||
// common.go
|
|
||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"fmt"
|
||||||
"sync"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseConnTrack provides common fields and locking for all connection types
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
type BaseConnTrack struct {
|
type BaseConnTrack struct {
|
||||||
SourceIP net.IP
|
FlowId uuid.UUID
|
||||||
DestIP net.IP
|
Direction nftypes.Direction
|
||||||
SourcePort uint16
|
SourceIP netip.Addr
|
||||||
DestPort uint16
|
DestIP netip.Addr
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64
|
||||||
|
PacketsTx atomic.Uint64
|
||||||
|
PacketsRx atomic.Uint64
|
||||||
|
BytesTx atomic.Uint64
|
||||||
|
BytesRx atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
|||||||
b.lastSeen.Store(time.Now().UnixNano())
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateCounters safely updates the packet and byte counters
|
||||||
|
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||||
|
if direction == nftypes.Egress {
|
||||||
|
b.PacketsTx.Add(1)
|
||||||
|
b.BytesTx.Add(uint64(bytes))
|
||||||
|
} else {
|
||||||
|
b.PacketsRx.Add(1)
|
||||||
|
b.BytesRx.Add(uint64(bytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetLastSeen safely gets the last seen timestamp
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
return time.Unix(0, b.lastSeen.Load())
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
|||||||
return time.Since(lastSeen) > timeout
|
return time.Since(lastSeen) > timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPAddr is a fixed-size IP address to avoid allocations
|
|
||||||
type IPAddr [16]byte
|
|
||||||
|
|
||||||
// MakeIPAddr creates an IPAddr from net.IP
|
|
||||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
|
||||||
// Optimization: check for v4 first as it's more common
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
copy(addr[12:], ip4)
|
|
||||||
} else {
|
|
||||||
copy(addr[:], ip.To16())
|
|
||||||
}
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnKey uniquely identifies a connection
|
// ConnKey uniquely identifies a connection
|
||||||
type ConnKey struct {
|
type ConnKey struct {
|
||||||
SrcIP IPAddr
|
SrcIP netip.Addr
|
||||||
DstIP IPAddr
|
DstIP netip.Addr
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeConnKey creates a connection key
|
func (c ConnKey) String() string {
|
||||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
return ConnKey{
|
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
|
||||||
DstIP: MakeIPAddr(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateIPs checks if IPs match without allocation
|
|
||||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
|
||||||
if ip4 := pktIP.To4(); ip4 != nil {
|
|
||||||
// Compare IPv4 addresses (last 4 bytes)
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
if connIP[12+i] != ip4[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Compare full IPv6 addresses
|
|
||||||
ip6 := pktIP.To16()
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
if connIP[i] != ip6[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
|
||||||
type PreallocatedIPs struct {
|
|
||||||
sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPreallocatedIPs creates a new IP pool
|
|
||||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
|
||||||
return &PreallocatedIPs{
|
|
||||||
Pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
ip := make(net.IP, 16)
|
|
||||||
return &ip
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get retrieves an IP from the pool
|
|
||||||
func (p *PreallocatedIPs) Get() net.IP {
|
|
||||||
return *p.Pool.Get().(*net.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put returns an IP to the pool
|
|
||||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
|
||||||
p.Pool.Put(&ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyIP copies an IP address efficiently
|
|
||||||
func copyIP(dst, src net.IP) {
|
|
||||||
if len(src) == 16 {
|
|
||||||
copy(dst, src)
|
|
||||||
} else {
|
|
||||||
// Handle IPv4
|
|
||||||
copy(dst[12:], src.To4())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,94 +1,66 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = MakeIPAddr(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ValidateIPs", func(b *testing.B) {
|
|
||||||
ip1 := net.ParseIP("192.168.1.1")
|
|
||||||
ip2 := net.ParseIP("192.168.1.1")
|
|
||||||
addr := MakeIPAddr(ip1)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = ValidateIPs(addr, ip2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IPPool", func(b *testing.B) {
|
|
||||||
pool := NewPreallocatedIPs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
ip := pool.Get()
|
|
||||||
pool.Put(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,18 +23,20 @@ const (
|
|||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
type ICMPConnKey struct {
|
type ICMPConnKey struct {
|
||||||
// Supports both IPv4 and IPv6
|
SrcIP netip.Addr
|
||||||
SrcIP [16]byte
|
DstIP netip.Addr
|
||||||
DstIP [16]byte
|
ID uint16
|
||||||
Sequence uint16 // ICMP sequence number
|
}
|
||||||
ID uint16 // ICMP identifier
|
|
||||||
|
func (i ICMPConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
type ICMPConnTrack struct {
|
type ICMPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
Sequence uint16
|
ICMPType uint8
|
||||||
ID uint16
|
ICMPCode uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
@@ -42,11 +47,11 @@ type ICMPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
@@ -59,67 +64,108 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP Echo Request
|
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
key := ICMPConnKey{
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
t.mutex.Lock()
|
ID: id,
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &ICMPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
},
|
|
||||||
ID: id,
|
|
||||||
Sequence: seq,
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New ICMP connection %v", key)
|
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
|
||||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound ICMP connection
|
||||||
|
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
|
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
|
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
|
||||||
|
// non echo requests don't need tracking
|
||||||
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &ICMPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||||
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
key := ICMPConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
conn.UpdateLastSeen()
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
conn.ID == id &&
|
|
||||||
conn.Sequence == seq
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
@@ -134,17 +180,18 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanup() {
|
func (t *ICMPTracker) cleanup() {
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -154,20 +201,46 @@ func (t *ICMPTracker) Close() {
|
|||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeICMPKey creates an ICMP connection key
|
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
return ICMPConnKey{
|
FlowID: conn.FlowId,
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
Type: typ,
|
||||||
DstIP: MakeIPAddr(dstIP),
|
RuleID: ruleID,
|
||||||
ID: id,
|
Direction: conn.Direction,
|
||||||
Sequence: seq,
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
}
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
ICMPType: conn.ICMPType,
|
||||||
|
ICMPCode: conn.ICMPCode,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeStart,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: direction,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
fields.RxPackets = 1
|
||||||
|
fields.RxBytes = uint64(size)
|
||||||
|
} else {
|
||||||
|
fields.TxPackets = 1
|
||||||
|
fields.TxBytes = uint64(size)
|
||||||
|
}
|
||||||
|
t.flowLogger.StoreEvent(fields)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,39 +1,39 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,11 +23,11 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPSyn uint8 = 0x02
|
|
||||||
TCPAck uint8 = 0x10
|
|
||||||
TCPFin uint8 = 0x01
|
TCPFin uint8 = 0x01
|
||||||
|
TCPSyn uint8 = 0x02
|
||||||
TCPRst uint8 = 0x04
|
TCPRst uint8 = 0x04
|
||||||
TCPPush uint8 = 0x08
|
TCPPush uint8 = 0x08
|
||||||
|
TCPAck uint8 = 0x10
|
||||||
TCPUrg uint8 = 0x20
|
TCPUrg uint8 = 0x20
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,7 +41,36 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// TCPState represents the state of a TCP connection
|
// TCPState represents the state of a TCP connection
|
||||||
type TCPState int
|
type TCPState int32
|
||||||
|
|
||||||
|
func (s TCPState) String() string {
|
||||||
|
switch s {
|
||||||
|
case TCPStateNew:
|
||||||
|
return "New"
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return "SYN Sent"
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return "SYN Received"
|
||||||
|
case TCPStateEstablished:
|
||||||
|
return "Established"
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return "FIN Wait 1"
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return "FIN Wait 2"
|
||||||
|
case TCPStateClosing:
|
||||||
|
return "Closing"
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
return "Time Wait"
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return "Close Wait"
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return "Last ACK"
|
||||||
|
case TCPStateClosed:
|
||||||
|
return "Closed"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPStateNew TCPState = iota
|
TCPStateNew TCPState = iota
|
||||||
@@ -54,30 +86,38 @@ const (
|
|||||||
TCPStateClosed
|
TCPStateClosed
|
||||||
)
|
)
|
||||||
|
|
||||||
// TCPConnKey uniquely identifies a TCP connection
|
|
||||||
type TCPConnKey struct {
|
|
||||||
SrcIP [16]byte
|
|
||||||
DstIP [16]byte
|
|
||||||
SrcPort uint16
|
|
||||||
DstPort uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCPConnTrack represents a TCP connection state
|
// TCPConnTrack represents a TCP connection state
|
||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
State TCPState
|
SourcePort uint16
|
||||||
established atomic.Bool
|
DestPort uint16
|
||||||
sync.RWMutex
|
state atomic.Int32
|
||||||
|
tombstone atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEstablished safely checks if connection is established
|
// GetState safely retrieves the current state
|
||||||
func (t *TCPConnTrack) IsEstablished() bool {
|
func (t *TCPConnTrack) GetState() TCPState {
|
||||||
return t.established.Load()
|
return TCPState(t.state.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetEstablished safely sets the established state
|
// SetState safely updates the current state
|
||||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||||
t.established.Store(state)
|
t.state.Store(int32(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||||
|
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||||
|
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTombstone safely checks if the connection is marked for deletion
|
||||||
|
func (t *TCPConnTrack) IsTombstone() bool {
|
||||||
|
return t.tombstone.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTombstone safely marks the connection for deletion
|
||||||
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
|
t.tombstone.Store(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
@@ -88,11 +128,18 @@ type TCPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
ipPool *PreallocatedIPs
|
waitTimeout time.Duration
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||||
|
waitTimeout := TimeWaitTimeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultTCPTimeout
|
||||||
|
} else {
|
||||||
|
waitTimeout = timeout / 45
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
@@ -102,179 +149,211 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
|||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
ipPool: NewPreallocatedIPs(),
|
waitTimeout: waitTimeout,
|
||||||
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
key := ConnKey{
|
||||||
// Create key before lock
|
SrcIP: srcIP,
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
t.mutex.Lock()
|
DstPort: dstPort,
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
// Use preallocated IPs
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &TCPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
State: TCPStateNew,
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
conn.established.Store(false)
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
// Lock individual connection for state update
|
|
||||||
conn.Lock()
|
|
||||||
t.updateState(conn, flags, true)
|
|
||||||
conn.Unlock()
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
|
||||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
|
||||||
if !isValidFlagCombination(flags) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
return false
|
t.updateState(key, conn, flags, direction, size)
|
||||||
|
return key, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle RST packets
|
return key, false
|
||||||
if flags&TCPRst != 0 {
|
|
||||||
conn.Lock()
|
|
||||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
conn.Unlock()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
conn.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Lock()
|
|
||||||
t.updateState(conn, flags, false)
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
isEstablished := conn.IsEstablished()
|
|
||||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
|
||||||
conn.Unlock()
|
|
||||||
|
|
||||||
return isEstablished || isValidState
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates the TCP connection state based on flags
|
// TrackOutbound records an outbound TCP connection
|
||||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||||
// Handle RST flag specially - it always causes transition to closed
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||||
if flags&TCPRst != 0 {
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
conn.State = TCPStateClosed
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||||
conn.SetEstablished(false)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||||
|
if exists || flags&TCPSyn == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch conn.State {
|
conn := &TCPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.tombstone.Store(false)
|
||||||
|
conn.state.Store(int32(TCPStateNew))
|
||||||
|
|
||||||
|
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||||
|
t.updateState(key, conn, flags, direction, size)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
|
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.IsTombstone() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := conn.GetState()
|
||||||
|
if !t.isValidStateForFlags(currentState, flags) {
|
||||||
|
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||||
|
// allow all flags for established for now
|
||||||
|
if currentState == TCPStateEstablished {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateState updates the TCP connection state based on flags
|
||||||
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(packetDir, size)
|
||||||
|
|
||||||
|
currentState := conn.GetState()
|
||||||
|
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var newState TCPState
|
||||||
|
switch currentState {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
conn.State = TCPStateSynSent
|
if conn.Direction == nftypes.Egress {
|
||||||
|
newState = TCPStateSynSent
|
||||||
|
} else {
|
||||||
|
newState = TCPStateSynReceived
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
if isOutbound {
|
if packetDir != conn.Direction {
|
||||||
conn.State = TCPStateSynReceived
|
newState = TCPStateEstablished
|
||||||
} else {
|
} else {
|
||||||
// Simultaneous open
|
// Simultaneous open
|
||||||
conn.State = TCPStateEstablished
|
newState = TCPStateSynReceived
|
||||||
conn.SetEstablished(true)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateSynReceived:
|
case TCPStateSynReceived:
|
||||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||||
conn.State = TCPStateEstablished
|
if packetDir == conn.Direction {
|
||||||
conn.SetEstablished(true)
|
newState = TCPStateEstablished
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateEstablished:
|
case TCPStateEstablished:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
if isOutbound {
|
if packetDir == conn.Direction {
|
||||||
conn.State = TCPStateFinWait1
|
newState = TCPStateFinWait1
|
||||||
} else {
|
} else {
|
||||||
conn.State = TCPStateCloseWait
|
newState = TCPStateCloseWait
|
||||||
}
|
}
|
||||||
conn.SetEstablished(false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
switch {
|
if packetDir != conn.Direction {
|
||||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
switch {
|
||||||
// Simultaneous close - both sides sent FIN
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
conn.State = TCPStateClosing
|
newState = TCPStateClosing
|
||||||
case flags&TCPFin != 0:
|
case flags&TCPFin != 0:
|
||||||
conn.State = TCPStateFinWait2
|
newState = TCPStateClosing
|
||||||
case flags&TCPAck != 0:
|
case flags&TCPAck != 0:
|
||||||
conn.State = TCPStateFinWait2
|
newState = TCPStateFinWait2
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait2:
|
case TCPStateFinWait2:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
newState = TCPStateTimeWait
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateClosing:
|
case TCPStateClosing:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
newState = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateLastAck
|
newState = TCPStateLastAck
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
newState = TCPStateClosed
|
||||||
|
|
||||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||||
// This is handled by the cleanup routine
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
switch newState {
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
case TCPStateTimeWait:
|
||||||
|
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
|
||||||
|
case TCPStateClosed:
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
if !isValidFlagCombination(flags) {
|
if !isValidFlagCombination(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
if state == TCPStateSynSent {
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
switch state {
|
switch state {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
|
// TODO: support simultaneous open
|
||||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||||
case TCPStateSynReceived:
|
case TCPStateSynReceived:
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateEstablished:
|
case TCPStateEstablished:
|
||||||
if flags&TCPRst != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
@@ -311,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateClosed:
|
case TCPStateClosed:
|
||||||
// Accept retransmitted ACKs in closed state
|
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||||
// This is important because the final ACK might be lost
|
|
||||||
// and the peer will retransmit their FIN-ACK
|
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -337,24 +418,33 @@ func (t *TCPTracker) cleanup() {
|
|||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
|
if conn.IsTombstone() {
|
||||||
|
// Clean up tombstoned connections without sending an event
|
||||||
|
delete(t.connections, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
switch {
|
currentState := conn.GetState()
|
||||||
case conn.State == TCPStateTimeWait:
|
switch currentState {
|
||||||
timeout = TimeWaitTimeout
|
case TCPStateTimeWait:
|
||||||
case conn.IsEstablished():
|
timeout = t.waitTimeout
|
||||||
|
case TCPStateEstablished:
|
||||||
timeout = t.timeout
|
timeout = t.timeout
|
||||||
default:
|
default:
|
||||||
timeout = TCPHandshakeTimeout
|
timeout = TCPHandshakeTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
lastSeen := conn.GetLastSeen()
|
if conn.timeoutExceeded(timeout) {
|
||||||
if time.Since(lastSeen) > timeout {
|
|
||||||
// Return IPs to pool
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||||
|
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
|
||||||
|
// event already handled by state change
|
||||||
|
if currentState != TCPStateTimeWait {
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -365,10 +455,6 @@ func (t *TCPTracker) Close() {
|
|||||||
|
|
||||||
// Clean up all remaining IPs
|
// Clean up all remaining IPs
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -386,3 +472,21 @@ func isValidFlagCombination(flags uint8) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
if i%2 == 0 {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
} else {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark connection cleanup
|
||||||
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Pre-populate with expired connections
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for connections to expire
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.cleanup()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,19 +1,20 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -58,7 +59,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -76,17 +77,17 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Send initial SYN
|
// Send initial SYN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
// Receive SYN-ACK
|
// Receive SYN-ACK
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
// Send ACK
|
// Send ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
// Test data transfer
|
// Test data transfer
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||||
require.True(t, valid, "Data should be allowed after handshake")
|
require.True(t, valid, "Data should be allowed after handshake")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -99,18 +100,18 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Send FIN
|
// Send FIN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
// Receive ACK for FIN
|
// Receive ACK for FIN
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "ACK for FIN should be allowed")
|
require.True(t, valid, "ACK for FIN should be allowed")
|
||||||
|
|
||||||
// Receive FIN from other side
|
// Receive FIN from other side
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "FIN should be allowed")
|
require.True(t, valid, "FIN should be allowed")
|
||||||
|
|
||||||
// Send final ACK
|
// Send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -122,11 +123,8 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Receive RST
|
// Receive RST
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
require.True(t, valid, "RST should be allowed for established connection")
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
|
||||||
// The connection will be cleaned up by timeout
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -138,13 +136,13 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Both sides send FIN+ACK
|
// Both sides send FIN+ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||||
|
|
||||||
// Both sides send final ACK
|
// Both sides send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "Final ACKs should be allowed")
|
require.True(t, valid, "Final ACKs should be allowed")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -154,7 +152,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -162,11 +160,11 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -181,12 +179,12 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST in established",
|
name: "RST in established",
|
||||||
setupState: func() {
|
setupState: func() {
|
||||||
// Establish connection first
|
// Establish connection first
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: true,
|
wantValid: true,
|
||||||
desc: "Should accept RST for established connection",
|
desc: "Should accept RST for established connection",
|
||||||
@@ -195,7 +193,7 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST without connection",
|
name: "RST without connection",
|
||||||
setupState: func() {},
|
setupState: func() {},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: false,
|
wantValid: false,
|
||||||
desc: "Should reject RST without connection",
|
desc: "Should reject RST without connection",
|
||||||
@@ -208,101 +206,455 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
tt.sendRST()
|
tt.sendRST()
|
||||||
|
|
||||||
// Verify connection state is as expected
|
// Verify connection state is as expected
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn := tracker.connections[key]
|
conn := tracker.connections[key]
|
||||||
if tt.wantValid {
|
if tt.wantValid {
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
require.Equal(t, TCPStateClosed, conn.State)
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
require.False(t, conn.IsEstablished())
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTCPRetransmissions(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test SYN retransmission
|
||||||
|
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||||
|
// Initial SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Retransmit SYN (should not affect the state machine)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Verify we're still in SYN-SENT state
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||||
|
|
||||||
|
// Complete the handshake
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// Verify we're in ESTABLISHED state
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test ACK retransmission in established state
|
||||||
|
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Retransmit ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// State should remain ESTABLISHED
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test FIN retransmission
|
||||||
|
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Retransmit FIN (should not change state)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPDataTransfer(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Data Transfer", func(t *testing.T) {
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||||
|
|
||||||
|
// Receive ACK for data
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Send ACK for received data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
|
||||||
|
// State should remain ESTABLISHED
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||||
|
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||||
|
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||||
|
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test half-closed connection: local end closes, remote end continues sending data
|
||||||
|
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Remote end can still send data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// We can still ACK their data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// Receive FIN from remote end
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test half-closed connection: remote end closes, local end continues sending data
|
||||||
|
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Receive FIN from remote
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
|
||||||
|
// We can still send data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||||
|
|
||||||
|
// Remote can still ACK our data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Send our FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
|
||||||
|
// Receive final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPAbnormalSequences(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test handling of unsolicited RST in various states
|
||||||
|
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||||
|
// Send SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Receive unsolicited RST (without proper ACK)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
|
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||||
|
|
||||||
|
// Receive RST with proper ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
|
// Create tracker with a very short timeout for testing
|
||||||
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Connection Timeout", func(t *testing.T) {
|
||||||
|
// Establish a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Wait for the connection to timeout
|
||||||
|
time.Sleep(2 * shortTimeout)
|
||||||
|
|
||||||
|
// Force cleanup
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Connection should be removed
|
||||||
|
_, exists := tracker.connections[key]
|
||||||
|
require.False(t, exists, "Connection should be removed after timeout")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Complete the connection close to enter TIME_WAIT
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||||
|
// For the test, we're using a short timeout
|
||||||
|
time.Sleep(2 * shortTimeout)
|
||||||
|
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Connection should be removed
|
||||||
|
_, exists := tracker.connections[key]
|
||||||
|
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSynFlood(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
basePort := uint16(10000)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Create a large number of SYN packets to simulate a SYN flood
|
||||||
|
for i := uint16(0); i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we're tracking all connections
|
||||||
|
require.Equal(t, 1000, len(tracker.connections))
|
||||||
|
|
||||||
|
// Now simulate SYN timeout
|
||||||
|
var oldConns int
|
||||||
|
tracker.mutex.Lock()
|
||||||
|
for _, conn := range tracker.connections {
|
||||||
|
if conn.GetState() == TCPStateSynSent {
|
||||||
|
// Make the connection appear old
|
||||||
|
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||||
|
oldConns++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracker.mutex.Unlock()
|
||||||
|
require.Equal(t, 1000, oldConns)
|
||||||
|
|
||||||
|
// Run cleanup
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Check that stale connections were cleaned up
|
||||||
|
require.Equal(t, 0, len(tracker.connections))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
clientPort := uint16(12345)
|
||||||
|
serverPort := uint16(80)
|
||||||
|
|
||||||
|
// 1. Client sends SYN (we receive it as inbound)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: clientIP,
|
||||||
|
DstIP: serverIP,
|
||||||
|
SrcPort: clientPort,
|
||||||
|
DstPort: serverPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.mutex.RLock()
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
tracker.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||||
|
|
||||||
|
// 2. Server sends SYN-ACK response
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
|
||||||
|
// 3. Client sends ACK to complete handshake
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||||
|
|
||||||
|
// 4. Test data transfer
|
||||||
|
// Client sends data
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||||
|
|
||||||
|
// Server sends ACK for data
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
// Server sends data
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||||
|
|
||||||
|
// Client sends ACK for data
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||||
|
|
||||||
|
// Verify state and counters
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||||
|
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||||
|
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||||
|
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||||
|
}
|
||||||
|
|
||||||
// Helper to establish a TCP connection
|
// Helper to establish a TCP connection
|
||||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
|
|
||||||
// Pre-populate some connections
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
|
||||||
i := 0
|
|
||||||
for pb.Next() {
|
|
||||||
if i%2 == 0 {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
|
||||||
} else {
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark connection cleanup
|
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
for i := 0; i < 10000; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections to expire
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.cleanup()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -19,6 +22,8 @@ const (
|
|||||||
// UDPConnTrack represents a UDP connection state
|
// UDPConnTrack represents a UDP connection state
|
||||||
type UDPConnTrack struct {
|
type UDPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPTracker manages UDP connection states
|
// UDPTracker manages UDP connection states
|
||||||
@@ -29,11 +34,11 @@ type UDPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
@@ -46,7 +51,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
@@ -54,55 +59,88 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.mutex.Lock()
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &UDPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
}
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New UDP connection: %v", conn)
|
|
||||||
}
|
}
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// TrackInbound records an inbound UDP connection
|
||||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
|
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
conn.UpdateLastSeen()
|
||||||
return false
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
}
|
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
return true
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
|
||||||
conn.DestPort == srcPort &&
|
|
||||||
conn.SourcePort == dstPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRoutine periodically removes stale connections
|
// cleanupRoutine periodically removes stale connections
|
||||||
@@ -125,11 +163,11 @@ func (t *UDPTracker) cleanup() {
|
|||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -139,29 +177,44 @@ func (t *UDPTracker) Close() {
|
|||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnection safely retrieves a connection state
|
// GetConnection safely retrieves a connection state
|
||||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
defer t.mutex.RUnlock()
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
conn, exists := t.connections[key]
|
SrcIP: srcIP,
|
||||||
if !exists {
|
DstIP: dstIP,
|
||||||
return nil, false
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
}
|
}
|
||||||
|
conn, exists := t.connections[key]
|
||||||
return conn, true
|
return conn, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout returns the configured timeout duration for the tracker
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
func (t *UDPTracker) Timeout() time.Duration {
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
return t.timeout
|
return t.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout, logger)
|
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
@@ -41,43 +41,48 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn, exists := tracker.connections[key]
|
conn, exists := tracker.connections[key]
|
||||||
require.True(t, exists)
|
require.True(t, exists)
|
||||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1*time.Second, logger)
|
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
// Track outbound connection
|
// Track outbound connection
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
sleep time.Duration
|
sleep time.Duration
|
||||||
@@ -94,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid source IP",
|
name: "invalid source IP",
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: srcIP,
|
dstIP: srcIP,
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
@@ -104,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid destination IP",
|
name: "invalid destination IP",
|
||||||
srcIP: dstIP,
|
srcIP: dstIP,
|
||||||
dstIP: net.ParseIP("192.168.1.4"),
|
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
sleep: 0,
|
sleep: 0,
|
||||||
@@ -144,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
if tt.sleep > 0 {
|
if tt.sleep > 0 {
|
||||||
time.Sleep(tt.sleep)
|
time.Sleep(tt.sleep)
|
||||||
}
|
}
|
||||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||||
assert.Equal(t, tt.want, got)
|
assert.Equal(t, tt.want, got)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -164,8 +169,8 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
tickerCancel: tickerCancel,
|
tickerCancel: tickerCancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
@@ -173,27 +178,27 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
// Add some connections
|
// Add some connections
|
||||||
connections := []struct {
|
connections := []struct {
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.2"),
|
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
dstIP: net.ParseIP("192.168.1.3"),
|
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||||
srcPort: 12345,
|
srcPort: 12345,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: net.ParseIP("192.168.1.5"),
|
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||||
srcPort: 12346,
|
srcPort: 12346,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, conn := range connections {
|
for _, conn := range connections {
|
||||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify initial connections
|
// Verify initial connections
|
||||||
@@ -215,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
|||||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type epID stack.TransportEndpointID
|
||||||
|
|
||||||
|
func (i epID) String() string {
|
||||||
|
// src and remote is swapped
|
||||||
|
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -29,6 +30,7 @@ const (
|
|||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
@@ -38,7 +40,7 @@ type Forwarder struct {
|
|||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &Forwarder{
|
f := &Forwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
stack: s,
|
stack: s,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
udpForwarder: newUDPForwarder(mtu, logger),
|
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
|
|||||||
@@ -3,14 +3,30 @@ package forwarder
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleICMP handles ICMP packets from the network stack
|
// handleICMP handles ICMP packets from the network stack
|
||||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
icmpType := uint8(icmpHdr.Type())
|
||||||
|
icmpCode := uint8(icmpHdr.Code())
|
||||||
|
|
||||||
|
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
// TODO: support non-root
|
// TODO: support non-root
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||||
|
|
||||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
return false
|
return false
|
||||||
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
dst := &net.IPAddr{IP: dstIP}
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
// Get the complete ICMP message (header + data)
|
|
||||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
payload := fullPacket.AsSlice()
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||||
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
// For Echo Requests, send and handle response
|
// For Echo Requests, send and handle response
|
||||||
switch icmpHdr.Type() {
|
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||||
case header.ICMPv4Echo:
|
f.handleEchoResponse(icmpHdr, conn, id)
|
||||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||||
case header.ICMPv4EchoReply:
|
|
||||||
// dont process our own replies
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||||
_, err = conn.WriteTo(payload, dst)
|
|
||||||
if err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, f.endpoint.mtu)
|
response := make([]byte, f.endpoint.mtu)
|
||||||
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
if !isTimeout(err) {
|
if !isTimeout(err) {
|
||||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
}
|
}
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
|
|
||||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
return true
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||||
return true
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
|
// TODO: get packets/bytes
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,24 +5,38 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleTCP is called by the TCP forwarder for new connections.
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
id := r.ID()
|
id := r.ID()
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
inConn := gonet.NewTCPConn(&wq, ep)
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyTCP(id, inConn, outConn, ep)
|
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
}
|
}
|
||||||
ep.Close()
|
ep.Close()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Create context for managing the proxy goroutines
|
// Create context for managing the proxy goroutines
|
||||||
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
@@ -16,6 +18,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,15 +31,17 @@ type udpPacketConn struct {
|
|||||||
lastSeen atomic.Int64
|
lastSeen atomic.Int64
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ep tcpip.Endpoint
|
ep tcpip.Endpoint
|
||||||
|
flowID uuid.UUID
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpForwarder struct {
|
type udpForwarder struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
flowLogger nftypes.FlowLogger
|
||||||
bufPool sync.Pool
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
ctx context.Context
|
bufPool sync.Pool
|
||||||
cancel context.CancelFunc
|
ctx context.Context
|
||||||
|
cancel context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
type idleConn struct {
|
type idleConn struct {
|
||||||
@@ -44,13 +49,14 @@ type idleConn struct {
|
|||||||
conn *udpPacketConn
|
conn *udpPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &udpForwarder{
|
f := &udpForwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
flowLogger: flowLogger,
|
||||||
ctx: ctx,
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
cancel: cancel,
|
ctx: ctx,
|
||||||
|
cancel: cancel,
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
b := make([]byte, mtu)
|
b := make([]byte, mtu)
|
||||||
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
|
|||||||
for id, conn := range f.conns {
|
for id, conn := range f.conns {
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
if err := conn.conn.Close(); err != nil {
|
if err := conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := conn.outConn.Close(); err != nil {
|
if err := conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.ep.Close()
|
conn.ep.Close()
|
||||||
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
|
|||||||
for _, idle := range idleConns {
|
for _, idle := range idleConns {
|
||||||
idle.conn.cancel()
|
idle.conn.cancel()
|
||||||
if err := idle.conn.conn.Close(); err != nil {
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
if err := idle.conn.outConn.Close(); err != nil {
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idle.conn.ep.Close()
|
idle.conn.ep.Close()
|
||||||
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
|
|||||||
delete(f.conns, idle.id)
|
delete(f.conns, idle.id)
|
||||||
f.Unlock()
|
f.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
_, exists := f.udpForwarder.conns[id]
|
_, exists := f.udpForwarder.conns[id]
|
||||||
f.udpForwarder.RUnlock()
|
f.udpForwarder.RUnlock()
|
||||||
if exists {
|
if exists {
|
||||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||||
// TODO: Send ICMP error message
|
// TODO: Send ICMP error message
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if epErr != nil {
|
if epErr != nil {
|
||||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
outConn: outConn,
|
outConn: outConn,
|
||||||
cancel: connCancel,
|
cancel: connCancel,
|
||||||
ep: ep,
|
ep: ep,
|
||||||
|
flowID: flowID,
|
||||||
}
|
}
|
||||||
pConn.updateLastSeen()
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[id] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
defer func() {
|
defer func() {
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := pConn.conn.Close(); err != nil {
|
if err := pConn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := pConn.outConn.Close(); err != nil {
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ep.Close()
|
ep.Close()
|
||||||
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
delete(f.udpForwarder.conns, id)
|
delete(f.udpForwarder.conns, id)
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *udpPacketConn) updateLastSeen() {
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
c.lastSeen.Store(time.Now().UnixNano())
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -13,8 +14,13 @@ import (
|
|||||||
type localIPManager struct {
|
type localIPManager struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
// fixed-size high array for upper byte of a IPv4 address
|
||||||
ipv4Bitmap [1 << 16]uint32
|
ipv4Bitmap [256]*ipv4LowBitmap
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||||
|
type ipv4LowBitmap struct {
|
||||||
|
bitmap [8192]uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLocalIPManager() *localIPManager {
|
func newLocalIPManager() *localIPManager {
|
||||||
@@ -26,39 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
if ipv4 == nil {
|
if ipv4 == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
high := uint16(ipv4[0])
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
index := low / 32
|
||||||
ipv4 := ip.To4()
|
bit := low % 32
|
||||||
if ipv4 == nil {
|
|
||||||
return false
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||||
}
|
}
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
high := uint16(ipv4[0])
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
if int(high) >= len(*newIPv4Bitmap) {
|
|
||||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
if bitmap[high] == nil {
|
||||||
|
bitmap[high] = &ipv4LowBitmap{}
|
||||||
}
|
}
|
||||||
ipStr := ip.String()
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
|
||||||
|
ipStr := ipv4.String()
|
||||||
if _, exists := ipv4Set[ipStr]; !exists {
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
ipv4Set[ipStr] = struct{}{}
|
ipv4Set[ipStr] = struct{}{}
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
|
high := uint16(ip[0])
|
||||||
|
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||||
|
|
||||||
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -76,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
log.Debugf("process IP failed: %v", err)
|
log.Debugf("process IP failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [1 << 16]uint32
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
ipv4Set := make(map[string]struct{})
|
ipv4Set := make(map[string]struct{})
|
||||||
var ipv4Addresses []string
|
var ipv4Addresses []string
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// 127.0.0.0/8
|
||||||
high := uint16(127) << 8
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
for i := uint16(0); i < 256; i++ {
|
for i := 0; i < 8192; i++ {
|
||||||
newIPv4Bitmap[high|i] = 0xffffffff
|
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||||
}
|
}
|
||||||
|
|
||||||
if iface != nil {
|
if iface != nil {
|
||||||
@@ -122,13 +148,13 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
|
if !ip.Is4() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
return m.checkBitmapBit(ipv4)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAddr wgaddr.Address
|
setupAddr wgaddr.Address
|
||||||
testIP net.IP
|
testIP netip.Addr
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -73,7 +74,19 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP doesn't match - addresses 32 apart",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -85,7 +98,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(64, 128),
|
Mask: net.CIDRMask(64, 128),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -174,7 +187,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
|||||||
t.Logf("Testing %d IPs", len(tests))
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.ip, func(t *testing.T) {
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -191,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
|
|||||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup bitmap version
|
// Setup bitmap
|
||||||
bitmapManager := &localIPManager{
|
bitmapManager := newLocalIPManager()
|
||||||
ipv4Bitmap: [1 << 16]uint32{},
|
|
||||||
}
|
|
||||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
bitmapManager.setBitmapBit(ip)
|
bitmapManager.setBitmapBit(ip)
|
||||||
}
|
}
|
||||||
@@ -247,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
|||||||
|
|
||||||
// Create two managers - one checks WG IP first, other checks it last
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
b.Run("WG_First", func(b *testing.B) {
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
bm := newLocalIPManager()
|
||||||
bm.setBitmapBit(wgIP)
|
bm.setBitmapBit(wgIP)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
@@ -256,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("WG_Last", func(b *testing.B) {
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
bm := newLocalIPManager()
|
||||||
// Fill with other IPs first
|
// Fill with other IPs first
|
||||||
for i := 0; i < 15; i++ {
|
for i := 0; i < 15; i++ {
|
||||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -13,13 +13,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2 // 2KB per message
|
maxMessageSize = 1024 * 2
|
||||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
logChannelSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
// Level represents log severity
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
|
|||||||
LevelTrace: "TRAC",
|
LevelTrace: "TRAC",
|
||||||
}
|
}
|
||||||
|
|
||||||
// Logger is a high-performance, non-blocking logger
|
type logMessage struct {
|
||||||
type Logger struct {
|
level Level
|
||||||
output io.Writer
|
format string
|
||||||
level atomic.Uint32
|
args []any
|
||||||
buffer *ringBuffer
|
|
||||||
shutdown chan struct{}
|
|
||||||
closeOnce sync.Once
|
|
||||||
wg sync.WaitGroup
|
|
||||||
|
|
||||||
// Reusable buffer pool for formatting messages
|
|
||||||
bufPool sync.Pool
|
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Logger is a high-performance, non-blocking logger
|
||||||
|
type Logger struct {
|
||||||
|
output io.Writer
|
||||||
|
level atomic.Uint32
|
||||||
|
msgChannel chan logMessage
|
||||||
|
shutdown chan struct{}
|
||||||
|
closeOnce sync.Once
|
||||||
|
wg sync.WaitGroup
|
||||||
|
bufPool sync.Pool
|
||||||
|
}
|
||||||
|
|
||||||
|
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
buffer: newRingBuffer(bufferSize),
|
msgChannel: make(chan logMessage, logChannelSize),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() any {
|
||||||
// Pre-allocate buffer for message formatting
|
|
||||||
b := make([]byte, 0, maxMessageSize)
|
b := make([]byte, 0, maxMessageSize)
|
||||||
return &b
|
return &b
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logrusLevel := logrusLogger.GetLevel()
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
l.level.Store(uint32(logrusLevel))
|
l.level.Store(uint32(logrusLevel))
|
||||||
level := levelStrings[Level(logrusLevel)]
|
level := levelStrings[Level(logrusLevel)]
|
||||||
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLevel sets the logging level
|
||||||
func (l *Logger) SetLevel(level Level) {
|
func (l *Logger) SetLevel(level Level) {
|
||||||
l.level.Store(uint32(level))
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
func (l *Logger) log(level Level, format string, args ...any) {
|
||||||
*buf = (*buf)[:0]
|
select {
|
||||||
|
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||||
// Timestamp
|
default:
|
||||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Level
|
|
||||||
*buf = append(*buf, levelStrings[level]...)
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Message
|
|
||||||
if len(args) > 0 {
|
|
||||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
|
||||||
} else {
|
|
||||||
*buf = append(*buf, format...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*buf = append(*buf, '\n')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
// Error logs a message at error level
|
||||||
bufp := l.bufPool.Get().(*[]byte)
|
func (l *Logger) Error(format string, args ...any) {
|
||||||
l.formatMessage(bufp, level, format, args...)
|
|
||||||
|
|
||||||
if len(*bufp) > maxMessageSize {
|
|
||||||
*bufp = (*bufp)[:maxMessageSize]
|
|
||||||
}
|
|
||||||
_, _ = l.buffer.Write(*bufp)
|
|
||||||
|
|
||||||
l.bufPool.Put(bufp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Error(format string, args ...interface{}) {
|
|
||||||
if l.level.Load() >= uint32(LevelError) {
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
l.log(LevelError, format, args...)
|
l.log(LevelError, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
// Warn logs a message at warning level
|
||||||
|
func (l *Logger) Warn(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelWarn) {
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
l.log(LevelWarn, format, args...)
|
l.log(LevelWarn, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Info(format string, args ...interface{}) {
|
// Info logs a message at info level
|
||||||
|
func (l *Logger) Info(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelInfo) {
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
l.log(LevelInfo, format, args...)
|
l.log(LevelInfo, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
// Debug logs a message at debug level
|
||||||
|
func (l *Logger) Debug(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelDebug) {
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
l.log(LevelDebug, format, args...)
|
l.log(LevelDebug, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
// Trace logs a message at trace level
|
||||||
|
func (l *Logger) Trace(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelTrace) {
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
l.log(LevelTrace, format, args...)
|
l.log(LevelTrace, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker periodically flushes the buffer
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
var msg string
|
||||||
|
if len(args) > 0 {
|
||||||
|
msg = fmt.Sprintf(format, args...)
|
||||||
|
} else {
|
||||||
|
msg = format
|
||||||
|
}
|
||||||
|
*buf = append(*buf, msg...)
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
|
||||||
|
if len(*buf) > maxMessageSize {
|
||||||
|
*buf = (*buf)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessage handles a single log message and adds it to the buffer
|
||||||
|
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
defer l.bufPool.Put(bufp)
|
||||||
|
|
||||||
|
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||||
|
|
||||||
|
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
*buffer = append(*buffer, *bufp...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushBuffer writes the accumulated buffer to output
|
||||||
|
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||||
|
if len(*buffer) > 0 {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatch processes as many messages as possible without blocking
|
||||||
|
func (l *Logger) processBatch(buffer *[]byte) {
|
||||||
|
for len(*buffer) < maxBatchSize {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||||
|
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(l.msgChannel) == 0 {
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is the main goroutine that processes log messages
|
||||||
func (l *Logger) worker() {
|
func (l *Logger) worker() {
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
|
||||||
ticker := time.NewTicker(defaultFlushInterval)
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
buf := make([]byte, 0, maxBatchSize)
|
buffer := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-l.shutdown:
|
case <-l.shutdown:
|
||||||
|
l.handleShutdown(&buffer)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Read accumulated messages
|
l.flushBuffer(&buffer)
|
||||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
case msg := <-l.msgChannel:
|
||||||
if n == 0 {
|
l.processMessage(msg, &buffer)
|
||||||
continue
|
l.processBatch(&buffer)
|
||||||
}
|
|
||||||
|
|
||||||
// Write batch
|
|
||||||
_, _ = l.output.Write(buf[:n])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type discard struct{}
|
||||||
|
|
||||||
|
func (d *discard) Write(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger(b *testing.B) {
|
||||||
|
simpleMessage := "Connection established"
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4 // TCPStateEstablished
|
||||||
|
|
||||||
|
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||||
|
protocol := "TCP"
|
||||||
|
direction := "outbound"
|
||||||
|
flags := uint16(0x18) // ACK + PSH
|
||||||
|
sequence := uint32(123456789)
|
||||||
|
acknowledged := uint32(987654321)
|
||||||
|
payloadSize := 1460
|
||||||
|
fragmented := false
|
||||||
|
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||||
|
|
||||||
|
b.Run("SimpleMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(simpleMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ComplexMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||||
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||||
|
func BenchmarkLoggerBurst(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestLogger() *log.Logger {
|
||||||
|
logrusLogger := logrus.New()
|
||||||
|
logrusLogger.SetOutput(&discard{})
|
||||||
|
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||||
|
return log.NewFromLogrus(logrusLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupLogger(logger *log.Logger) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = logger.Stop(ctx)
|
||||||
|
}
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
package log
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// ringBuffer is a simple ring buffer implementation
|
|
||||||
type ringBuffer struct {
|
|
||||||
buf []byte
|
|
||||||
size int
|
|
||||||
r, w int64 // Read and write positions
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRingBuffer(size int) *ringBuffer {
|
|
||||||
return &ringBuffer{
|
|
||||||
buf: make([]byte, size),
|
|
||||||
size: size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
|
||||||
if len(p) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if len(p) > r.size {
|
|
||||||
p = p[:r.size]
|
|
||||||
}
|
|
||||||
|
|
||||||
n = len(p)
|
|
||||||
|
|
||||||
// Write data, handling wrap-around
|
|
||||||
pos := int(r.w % int64(r.size))
|
|
||||||
writeLen := min(len(p), r.size-pos)
|
|
||||||
copy(r.buf[pos:], p[:writeLen])
|
|
||||||
|
|
||||||
// If we have more data and need to wrap around
|
|
||||||
if writeLen < len(p) {
|
|
||||||
copy(r.buf, p[writeLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update write position
|
|
||||||
r.w += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if r.w == r.r {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate available data accounting for wraparound
|
|
||||||
available := int(r.w - r.r)
|
|
||||||
if available < 0 {
|
|
||||||
available += r.size
|
|
||||||
}
|
|
||||||
available = min(available, r.size)
|
|
||||||
|
|
||||||
// Limit read to buffer size
|
|
||||||
toRead := min(available, len(p))
|
|
||||||
if toRead == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read data, handling wrap-around
|
|
||||||
pos := int(r.r % int64(r.size))
|
|
||||||
readLen := min(toRead, r.size-pos)
|
|
||||||
n = copy(p, r.buf[pos:pos+readLen])
|
|
||||||
|
|
||||||
// If we need more data and need to wrap around
|
|
||||||
if readLen < toRead {
|
|
||||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update read position
|
|
||||||
r.r += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -12,14 +11,14 @@ import (
|
|||||||
// PeerRule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type PeerRule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
mgmtId []byte
|
||||||
|
ip netip.Addr
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
sPort *firewall.Port
|
sPort *firewall.Port
|
||||||
dPort *firewall.Port
|
dPort *firewall.Port
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
@@ -30,13 +29,15 @@ func (r *PeerRule) ID() string {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type RouteRule struct {
|
type RouteRule struct {
|
||||||
id string
|
id string
|
||||||
sources []netip.Prefix
|
mgmtId []byte
|
||||||
destination netip.Prefix
|
sources []netip.Prefix
|
||||||
proto firewall.Protocol
|
dstSet firewall.Set
|
||||||
srcPort *firewall.Port
|
destinations []netip.Prefix
|
||||||
dstPort *firewall.Port
|
proto firewall.Protocol
|
||||||
action firewall.Action
|
srcPort *firewall.Port
|
||||||
|
dstPort *firewall.Port
|
||||||
|
action firewall.Action
|
||||||
}
|
}
|
||||||
|
|
||||||
// ID returns the rule id
|
// ID returns the rule id
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketTrace struct {
|
type PacketTrace struct {
|
||||||
SourceIP net.IP
|
SourceIP netip.Addr
|
||||||
DestinationIP net.IP
|
DestinationIP netip.Addr
|
||||||
Protocol string
|
Protocol string
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestinationPort uint16
|
DestinationPort uint16
|
||||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketBuilder struct {
|
type PacketBuilder struct {
|
||||||
SrcIP net.IP
|
SrcIP netip.Addr
|
||||||
DstIP net.IP
|
DstIP netip.Addr
|
||||||
Protocol fw.Protocol
|
Protocol fw.Protocol
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
|||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
SrcIP: p.SrcIP,
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
DstIP: p.DstIP,
|
DstIP: p.DstIP.AsSlice(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
return trace
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
|
return trace
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !m.handleRouting(trace) {
|
if !m.handleRouting(trace) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return m.handleNativeRouter(trace)
|
return m.handleNativeRouter(trace)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||||
msg := "No existing connection found"
|
msg := "No existing connection found"
|
||||||
if allowed {
|
if allowed {
|
||||||
msg = m.buildConntrackStateMessage(d)
|
msg = m.buildConntrackStateMessage(d)
|
||||||
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
if !m.localForwarding {
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
strRuleId := "<no id>"
|
||||||
|
if ruleId != nil {
|
||||||
|
strRuleId = string(ruleId)
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||||
|
if blocked {
|
||||||
|
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||||
|
trace.AddResult(StagePeerACL, msg, false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
trace.AddResult(StagePeerACL, msg, true)
|
||||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
|
||||||
|
|
||||||
msg := "Allowed by peer ACL rules"
|
|
||||||
if blocked {
|
|
||||||
msg = "Blocked by peer ACL rules"
|
|
||||||
}
|
|
||||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
|
||||||
|
|
||||||
|
// Handle netstack mode
|
||||||
if m.netstack {
|
if m.netstack {
|
||||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
switch {
|
||||||
|
case !m.localForwarding:
|
||||||
|
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||||
|
case m.forwarder.Load() != nil:
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
default:
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
// In normal mode, packets are allowed through for local delivery
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
return false
|
return false
|
||||||
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
|||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||||
proto := getProtocolFromPacket(d)
|
proto, _ := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
msg := "Allowed by route ACLs"
|
strId := string(id)
|
||||||
|
if id == nil {
|
||||||
|
strId = "<no id>"
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||||
if !allowed {
|
if !allowed {
|
||||||
msg = "Blocked by route ACLs"
|
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||||
}
|
}
|
||||||
trace.AddResult(StageRouteACL, msg, allowed)
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
if allowed && m.forwarder != nil {
|
if allowed && m.forwarder.Load() != nil {
|
||||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.processOutgoingHooks(packetData)
|
dropped := m.processOutgoingHooks(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
440
client/firewall/uspfilter/tracer_test.go
Normal file
440
client/firewall/uspfilter/tracer_test.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||||
|
t.Logf("Trace results: %v", trace.Results)
|
||||||
|
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||||
|
for _, result := range trace.Results {
|
||||||
|
actualStages = append(actualStages, result.Stage)
|
||||||
|
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||||
|
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||||
|
lastResult := trace.Results[len(trace.Results)-1]
|
||||||
|
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||||
|
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTracePacket(t *testing.T) {
|
||||||
|
setupTracerTest := func(statefulMode bool) *Manager {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !statefulMode {
|
||||||
|
m.stateful = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
builder := &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocol == "tcp" {
|
||||||
|
builder.TCPState = &TCPState{SYN: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder
|
||||||
|
}
|
||||||
|
|
||||||
|
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
return &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: "icmp",
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*Manager)
|
||||||
|
packetBuilder func() *PacketBuilder
|
||||||
|
expectedStages []PacketStage
|
||||||
|
expectedAllow bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = true
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithoutForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{192, 168, 17, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, fw.Network{Prefix: dst}, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_NativeRouter",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(true)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_RoutingDisabled",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "192.168.17.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ConnectionTracking_Hit",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||||
|
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||||
|
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||||
|
return pb
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OutboundTraffic",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPEchoRequest",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPDestinationUnreachable",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithoutHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolUDP
|
||||||
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
hookFunc := func([]byte) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "StatefulDisabled_NoTracking",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.stateful = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
m := setupTracerTest(true)
|
||||||
|
|
||||||
|
tc.setup(m)
|
||||||
|
|
||||||
|
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||||
|
"100.10.0.100 should be recognized as a local IP")
|
||||||
|
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("192.168.17.2")),
|
||||||
|
"192.168.17.2 should not be recognized as a local IP")
|
||||||
|
|
||||||
|
pb := tc.packetBuilder()
|
||||||
|
|
||||||
|
trace, err := m.TracePacketFromBuilder(pb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifyTraceStages(t, trace, tc.expectedStages)
|
||||||
|
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
File diff suppressed because it is too large
Load Diff
@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: false,
|
stateful: false,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Single rule allowing all traffic
|
// Single rule allowing all traffic
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||||
fw.ActionAccept, "", "allow all")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Add explicit rules matching return traffic pattern
|
// Add explicit rules matching return traffic pattern
|
||||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
ip := generateRandomIPs(1)[0]
|
ip := generateRandomIPs(1)[0]
|
||||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
|
ip,
|
||||||
|
fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
&fw.Port{Values: []uint16{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
fw.ActionAccept, "", "explicit return")
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: true,
|
stateful: true,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Add some basic rules but rely on state for established connections
|
// Add some basic rules but rely on state for established connections
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
_, err := m.AddPeerFiltering(
|
||||||
fw.ActionDrop, "", "default drop")
|
nil,
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionDrop,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Connection tracking with established connections",
|
desc: "Connection tracking with established connections",
|
||||||
@@ -158,7 +169,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -203,7 +214,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.processOutgoingHooks(testOut)
|
manager.processOutgoingHooks(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn)
|
manager.dropFilter(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -251,7 +262,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -450,7 +461,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -577,7 +588,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -668,7 +676,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -787,7 +792,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -875,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(
|
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
r.sources,
|
|
||||||
r.dest,
|
|
||||||
r.proto,
|
|
||||||
nil,
|
|
||||||
r.port,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/mocks"
|
"github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestPeerACLFiltering(t *testing.T) {
|
func TestPeerACLFiltering(t *testing.T) {
|
||||||
@@ -34,7 +35,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@@ -188,24 +189,321 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
ruleAction: fw.ActionAccept,
|
ruleAction: fw.ActionAccept,
|
||||||
shouldBeBlocked: true,
|
shouldBeBlocked: true,
|
||||||
},
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow UDP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP packet doesn't match UDP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP packet doesn't match TCP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match TCP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match UDP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic within port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Block TCP traffic outside port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Edge Case - Port at Range Boundary",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8100,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP Port Range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 5060,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{5060, 5070}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow multiple destination ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow multiple source ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
|
ruleAction: fw.ActionAccept,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
// New drop test cases
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop UDP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolUDP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{53}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop ICMP traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolICMP,
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop all traffic from WG peer",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolALL,
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop traffic from multiple source ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop multiple destination ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic within port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 8080,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Accept TCP traffic outside drop port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Drop TCP traffic with source port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 32100,
|
||||||
|
dstPort: 80,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleSrcPort: &fw.Port{IsRange: true, Values: []uint16{32000, 33000}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Mixed rule - drop specific port but allow other ports",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "100.10.0.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
ruleIP: "100.10.0.1",
|
||||||
|
ruleProto: fw.ProtocolTCP,
|
||||||
|
ruleDstPort: &fw.Port{Values: []uint16{443}},
|
||||||
|
ruleAction: fw.ActionDrop,
|
||||||
|
shouldBeBlocked: true,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
|
||||||
|
if tc.ruleAction == fw.ActionDrop {
|
||||||
|
// add general accept rule to test drop rule
|
||||||
|
// TODO: this only works because 0.0.0.0 is tested last, we need to implement order
|
||||||
|
rules, err := manager.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
fw.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotEmpty(t, rules)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
for _, rule := range rules {
|
||||||
|
require.NoError(t, manager.DeletePeerRule(rule))
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
net.ParseIP(tc.ruleIP),
|
||||||
tc.ruleProto,
|
tc.ruleProto,
|
||||||
tc.ruleSrcPort,
|
tc.ruleSrcPort,
|
||||||
tc.ruleDstPort,
|
tc.ruleDstPort,
|
||||||
tc.ruleAction,
|
tc.ruleAction,
|
||||||
"",
|
"",
|
||||||
tc.name,
|
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules)
|
require.NotEmpty(t, rules)
|
||||||
@@ -217,7 +515,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -302,12 +600,12 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled)
|
require.True(tb, manager.routingEnabled.Load())
|
||||||
require.False(tb, manager.nativeRouter)
|
require.False(tb, manager.nativeRouter.Load())
|
||||||
|
|
||||||
tb.Cleanup(func() {
|
tb.Cleanup(func() {
|
||||||
require.NoError(tb, manager.Close(nil))
|
require.NoError(tb, manager.Close(nil))
|
||||||
@@ -321,7 +619,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
type rule struct {
|
type rule struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -347,7 +645,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -363,7 +661,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -379,7 +677,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -395,7 +693,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolUDP,
|
proto: fw.ProtocolUDP,
|
||||||
dstPort: &fw.Port{Values: []uint16{53}},
|
dstPort: &fw.Port{Values: []uint16{53}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -409,7 +707,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
@@ -424,7 +722,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -440,7 +738,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -456,7 +754,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -472,7 +770,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -488,7 +786,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -507,7 +805,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
},
|
},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -521,7 +819,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
proto: fw.ProtocolICMP,
|
proto: fw.ProtocolICMP,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
@@ -536,33 +834,13 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
shouldPass: true,
|
shouldPass: true,
|
||||||
},
|
},
|
||||||
{
|
|
||||||
name: "Multiple source networks with mismatched protocol",
|
|
||||||
srcIP: "172.16.0.1",
|
|
||||||
dstIP: "192.168.1.100",
|
|
||||||
// Should not match TCP rule
|
|
||||||
proto: fw.ProtocolUDP,
|
|
||||||
srcPort: 12345,
|
|
||||||
dstPort: 80,
|
|
||||||
rule: rule{
|
|
||||||
sources: []netip.Prefix{
|
|
||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
|
||||||
},
|
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
|
||||||
proto: fw.ProtocolTCP,
|
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
|
||||||
action: fw.ActionAccept,
|
|
||||||
},
|
|
||||||
shouldPass: false,
|
|
||||||
},
|
|
||||||
{
|
{
|
||||||
name: "Allow multiple destination ports",
|
name: "Allow multiple destination ports",
|
||||||
srcIP: "100.10.0.1",
|
srcIP: "100.10.0.1",
|
||||||
@@ -572,7 +850,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
dstPort: &fw.Port{Values: []uint16{80, 8080, 443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -588,7 +866,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
srcPort: &fw.Port{Values: []uint16{12345, 12346, 12347}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -604,7 +882,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
srcPort: &fw.Port{Values: []uint16{12345}},
|
srcPort: &fw.Port{Values: []uint16{12345}},
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
@@ -621,7 +899,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -640,7 +918,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 7999,
|
dstPort: 7999,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -659,7 +937,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{
|
srcPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -678,7 +956,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
srcPort: &fw.Port{
|
srcPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -700,7 +978,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8100,
|
dstPort: 8100,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -719,7 +997,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 5060,
|
dstPort: 5060,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolUDP,
|
proto: fw.ProtocolUDP,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -738,7 +1016,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 8080,
|
dstPort: 8080,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
dstPort: &fw.Port{
|
dstPort: &fw.Port{
|
||||||
IsRange: true,
|
IsRange: true,
|
||||||
@@ -757,7 +1035,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 443,
|
dstPort: 443,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -773,7 +1051,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
dstPort: 80,
|
dstPort: 80,
|
||||||
rule: rule{
|
rule: rule{
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
},
|
},
|
||||||
@@ -791,18 +1069,160 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
netip.MustParsePrefix("100.10.0.0/16"),
|
netip.MustParsePrefix("100.10.0.0/16"),
|
||||||
netip.MustParsePrefix("172.16.0.0/16"),
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
},
|
},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
},
|
},
|
||||||
shouldPass: false,
|
shouldPass: false,
|
||||||
},
|
},
|
||||||
|
|
||||||
|
{
|
||||||
|
name: "Drop empty destination set",
|
||||||
|
srcIP: "172.16.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
},
|
||||||
|
dest: fw.Network{Set: fw.Set{}},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Accept TCP traffic outside drop port range",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 7999,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{IsRange: true, Values: []uint16{8000, 8100}},
|
||||||
|
action: fw.ActionDrop,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow TCP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 443,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Allow UDP traffic without port specification",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 53,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "TCP packet doesn't match UDP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDP packet doesn't match TCP filter with same port",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
srcPort: 12345,
|
||||||
|
dstPort: 80,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match TCP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolTCP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMP packet doesn't match UDP filter",
|
||||||
|
srcIP: "100.10.0.1",
|
||||||
|
dstIP: "192.168.1.100",
|
||||||
|
proto: fw.ProtocolICMP,
|
||||||
|
rule: rule{
|
||||||
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
|
proto: fw.ProtocolUDP,
|
||||||
|
action: fw.ActionAccept,
|
||||||
|
},
|
||||||
|
shouldPass: false,
|
||||||
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
if tc.rule.action == fw.ActionDrop {
|
||||||
|
// add general accept rule to test drop rule
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.ProtocolALL,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
tc.rule.dest,
|
tc.rule.dest,
|
||||||
tc.rule.proto,
|
tc.rule.proto,
|
||||||
@@ -817,12 +1237,12 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
})
|
})
|
||||||
|
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -835,7 +1255,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
rules []struct {
|
rules []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -856,7 +1276,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name: "Drop rules take precedence over accept",
|
name: "Drop rules take precedence over accept",
|
||||||
rules: []struct {
|
rules: []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -865,7 +1285,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Accept rule added first
|
// Accept rule added first
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
dstPort: &fw.Port{Values: []uint16{80, 443}},
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
@@ -873,7 +1293,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Drop rule added second but should be evaluated first
|
// Drop rule added second but should be evaluated first
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -911,7 +1331,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
name: "Multiple drop rules take precedence",
|
name: "Multiple drop rules take precedence",
|
||||||
rules: []struct {
|
rules: []struct {
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
dest netip.Prefix
|
dest fw.Network
|
||||||
proto fw.Protocol
|
proto fw.Protocol
|
||||||
srcPort *fw.Port
|
srcPort *fw.Port
|
||||||
dstPort *fw.Port
|
dstPort *fw.Port
|
||||||
@@ -920,14 +1340,14 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Accept all
|
// Accept all
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
sources: []netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
dest: netip.MustParsePrefix("0.0.0.0/0"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
proto: fw.ProtocolALL,
|
proto: fw.ProtocolALL,
|
||||||
action: fw.ActionAccept,
|
action: fw.ActionAccept,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
// Drop specific port
|
// Drop specific port
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{443}},
|
dstPort: &fw.Port{Values: []uint16{443}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -935,7 +1355,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
{
|
{
|
||||||
// Drop different port
|
// Drop different port
|
||||||
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
sources: []netip.Prefix{netip.MustParsePrefix("100.10.0.0/16")},
|
||||||
dest: netip.MustParsePrefix("192.168.1.0/24"),
|
dest: fw.Network{Prefix: netip.MustParsePrefix("192.168.1.0/24")},
|
||||||
proto: fw.ProtocolTCP,
|
proto: fw.ProtocolTCP,
|
||||||
dstPort: &fw.Port{Values: []uint16{80}},
|
dstPort: &fw.Port{Values: []uint16{80}},
|
||||||
action: fw.ActionDrop,
|
action: fw.ActionDrop,
|
||||||
@@ -985,6 +1405,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
var rules []fw.Rule
|
var rules []fw.Rule
|
||||||
for _, r := range tc.rules {
|
for _, r := range tc.rules {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
r.sources,
|
r.sources,
|
||||||
r.dest,
|
r.dest,
|
||||||
r.proto,
|
r.proto,
|
||||||
@@ -1004,12 +1425,62 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
for i, p := range tc.packets {
|
for i, p := range tc.packets {
|
||||||
srcIP := net.ParseIP(p.srcIP)
|
srcIP := netip.MustParseAddr(p.srcIP)
|
||||||
dstIP := net.ParseIP(p.dstIP)
|
dstIP := netip.MustParseAddr(p.dstIP)
|
||||||
|
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestRouteACLSet(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
// Add rule that uses the set (initially empty)
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.100")
|
||||||
|
|
||||||
|
// Check that traffic is dropped (empty set shouldn't match anything)
|
||||||
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.False(t, isAllowed, "Empty set should not allow any traffic")
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, []netip.Prefix{netip.MustParsePrefix("192.168.1.0/24")})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Now the packet should be allowed
|
||||||
|
_, isAllowed = manager.routeACLsPass(srcIP, dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.True(t, isAllowed, "After set update, traffic to the added network should be allowed")
|
||||||
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,9 +19,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
@@ -62,7 +66,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -82,7 +86,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -92,9 +96,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -116,26 +119,25 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,7 +151,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,7 +162,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
in bool
|
in bool
|
||||||
expDir fw.RuleDirection
|
expDir fw.RuleDirection
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
dPort uint16
|
dPort uint16
|
||||||
hook func([]byte) bool
|
hook func([]byte) bool
|
||||||
expectedID string
|
expectedID string
|
||||||
@@ -169,7 +171,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Outgoing UDP Packet Hook",
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
in: false,
|
in: false,
|
||||||
expDir: fw.RuleDirectionOUT,
|
expDir: fw.RuleDirectionOUT,
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
ip: netip.MustParseAddr("10.168.0.1"),
|
||||||
dPort: 8000,
|
dPort: 8000,
|
||||||
hook: func([]byte) bool { return true },
|
hook: func([]byte) bool { return true },
|
||||||
},
|
},
|
||||||
@@ -177,7 +179,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Incoming UDP Packet Hook",
|
name: "Test Incoming UDP Packet Hook",
|
||||||
in: true,
|
in: true,
|
||||||
expDir: fw.RuleDirectionIN,
|
expDir: fw.RuleDirectionIN,
|
||||||
ip: net.IPv6loopback,
|
ip: netip.MustParseAddr("::1"),
|
||||||
dPort: 9000,
|
dPort: 9000,
|
||||||
hook: func([]byte) bool { return false },
|
hook: func([]byte) bool { return false },
|
||||||
},
|
},
|
||||||
@@ -187,18 +189,18 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule PeerRule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
for _, rule := range manager.incomingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -206,12 +208,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -236,7 +238,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -246,9 +248,8 @@ func TestManagerReset(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -279,7 +280,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -292,9 +293,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -328,7 +328,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes()) {
|
if m.dropFilter(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -347,7 +347,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface, false)
|
manager, err := Create(iface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -357,7 +357,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
@@ -393,7 +393,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -401,7 +401,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
@@ -423,7 +423,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
hookCalled := false
|
hookCalled := false
|
||||||
hookID := manager.AddUDPPacketHook(
|
hookID := manager.AddUDPPacketHook(
|
||||||
false,
|
false,
|
||||||
net.ParseIP("100.10.0.100"),
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
53,
|
53,
|
||||||
func([]byte) bool {
|
func([]byte) bool {
|
||||||
hookCalled = true
|
hookCalled = true
|
||||||
@@ -458,7 +458,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.processOutgoingHooks(buf.Bytes())
|
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@@ -468,7 +468,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.processOutgoingHooks(buf.Bytes())
|
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +479,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -494,7 +494,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -506,7 +506,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -515,7 +515,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -534,8 +534,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up packet parameters
|
// Set up packet parameters
|
||||||
srcIP := net.ParseIP("100.10.0.1")
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
dstIP := net.ParseIP("100.10.0.100")
|
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||||
srcPort := uint16(51334)
|
srcPort := uint16(51334)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
@@ -543,8 +543,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
outboundIPv4 := &layers.IPv4{
|
outboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: srcIP,
|
SrcIP: srcIP.AsSlice(),
|
||||||
DstIP: dstIP,
|
DstIP: dstIP.AsSlice(),
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
outboundUDP := &layers.UDP{
|
outboundUDP := &layers.UDP{
|
||||||
@@ -569,15 +569,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
@@ -585,8 +585,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
inboundIPv4 := &layers.IPv4{
|
inboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: dstIP, // Original destination is now source
|
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||||
DstIP: srcIP, // Original source is now destination
|
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
inboundUDP := &layers.UDP{
|
inboundUDP := &layers.UDP{
|
||||||
@@ -636,7 +636,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -685,7 +685,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@@ -707,8 +707,208 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes())
|
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestUpdateSetMerge(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
initialPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
// Update the set with initial prefixes
|
||||||
|
err = manager.UpdateSet(set, initialPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Test initial prefixes work
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
dstIP1 := netip.MustParseAddr("10.0.0.100")
|
||||||
|
dstIP2 := netip.MustParseAddr("192.168.1.100")
|
||||||
|
dstIP3 := netip.MustParseAddr("172.16.0.100")
|
||||||
|
|
||||||
|
_, isAllowed1 := manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed2 := manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed3 := manager.routeACLsPass(srcIP, dstIP3, fw.ProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should be allowed")
|
||||||
|
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should be allowed")
|
||||||
|
require.False(t, isAllowed3, "Traffic to 172.16.0.100 should be denied")
|
||||||
|
|
||||||
|
newPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("172.16.0.0/16"),
|
||||||
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, newPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check that all original prefixes are still included
|
||||||
|
_, isAllowed1 = manager.routeACLsPass(srcIP, dstIP1, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed2 = manager.routeACLsPass(srcIP, dstIP2, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.True(t, isAllowed1, "Traffic to 10.0.0.100 should still be allowed after update")
|
||||||
|
require.True(t, isAllowed2, "Traffic to 192.168.1.100 should still be allowed after update")
|
||||||
|
|
||||||
|
// Check that new prefixes are included
|
||||||
|
dstIP4 := netip.MustParseAddr("172.16.1.100")
|
||||||
|
dstIP5 := netip.MustParseAddr("10.1.0.50")
|
||||||
|
|
||||||
|
_, isAllowed4 := manager.routeACLsPass(srcIP, dstIP4, fw.ProtocolTCP, 12345, 80)
|
||||||
|
_, isAllowed5 := manager.routeACLsPass(srcIP, dstIP5, fw.ProtocolTCP, 12345, 80)
|
||||||
|
|
||||||
|
require.True(t, isAllowed4, "Traffic to new prefix 172.16.0.0/16 should be allowed")
|
||||||
|
require.True(t, isAllowed5, "Traffic to new prefix 10.1.0.0/24 should be allowed")
|
||||||
|
|
||||||
|
// Verify the rule has all prefixes
|
||||||
|
manager.mutex.RLock()
|
||||||
|
foundRule := false
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
foundRule = true
|
||||||
|
require.Len(t, r.destinations, len(initialPrefixes)+len(newPrefixes),
|
||||||
|
"Rule should have all prefixes merged")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
require.True(t, foundRule, "Rule should be found")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateSetDeduplication(t *testing.T) {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
}
|
||||||
|
|
||||||
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, manager.Close(nil))
|
||||||
|
})
|
||||||
|
|
||||||
|
set := fw.NewDomainSet(domain.List{"example.org"})
|
||||||
|
|
||||||
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
|
[]netip.Prefix{netip.MustParsePrefix("0.0.0.0/0")},
|
||||||
|
fw.Network{Set: set},
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionAccept,
|
||||||
|
)
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NotNil(t, rule)
|
||||||
|
|
||||||
|
initialPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"), // Duplicate
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"), // Duplicate
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, initialPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check the internal state for deduplication
|
||||||
|
manager.mutex.RLock()
|
||||||
|
foundRule := false
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
foundRule = true
|
||||||
|
// Should have deduplicated to 2 prefixes
|
||||||
|
require.Len(t, r.destinations, 2, "Duplicate prefixes should be removed")
|
||||||
|
|
||||||
|
// Check the prefixes are correct
|
||||||
|
expectedPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"),
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"),
|
||||||
|
}
|
||||||
|
for i, prefix := range expectedPrefixes {
|
||||||
|
require.True(t, r.destinations[i] == prefix,
|
||||||
|
"Prefix should match expected value")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
require.True(t, foundRule, "Rule should be found")
|
||||||
|
|
||||||
|
// Test with overlapping prefixes of different sizes
|
||||||
|
overlappingPrefixes := []netip.Prefix{
|
||||||
|
netip.MustParsePrefix("10.0.0.0/16"), // More general
|
||||||
|
netip.MustParsePrefix("10.0.0.0/24"), // More specific (already exists)
|
||||||
|
netip.MustParsePrefix("192.168.0.0/16"), // More general
|
||||||
|
netip.MustParsePrefix("192.168.1.0/24"), // More specific (already exists)
|
||||||
|
}
|
||||||
|
|
||||||
|
err = manager.UpdateSet(set, overlappingPrefixes)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
// Check that all prefixes are included (no deduplication of overlapping prefixes)
|
||||||
|
manager.mutex.RLock()
|
||||||
|
for _, r := range manager.routeRules {
|
||||||
|
if r.id == rule.ID() {
|
||||||
|
// Should have all 4 prefixes (2 original + 2 new more general ones)
|
||||||
|
require.Len(t, r.destinations, 4,
|
||||||
|
"Overlapping prefixes should not be deduplicated")
|
||||||
|
|
||||||
|
// Verify they're sorted correctly (more specific prefixes should come first)
|
||||||
|
prefixes := make([]string, 0, len(r.destinations))
|
||||||
|
for _, p := range r.destinations {
|
||||||
|
prefixes = append(prefixes, p.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check sorted order
|
||||||
|
require.Equal(t, []string{
|
||||||
|
"10.0.0.0/16",
|
||||||
|
"10.0.0.0/24",
|
||||||
|
"192.168.0.0/16",
|
||||||
|
"192.168.1.0/24",
|
||||||
|
}, prefixes, "Prefixes should be sorted")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
manager.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Test functionality with all prefixes
|
||||||
|
testCases := []struct {
|
||||||
|
dstIP netip.Addr
|
||||||
|
expected bool
|
||||||
|
desc string
|
||||||
|
}{
|
||||||
|
{netip.MustParseAddr("10.0.0.100"), true, "IP in both /16 and /24"},
|
||||||
|
{netip.MustParseAddr("10.0.1.100"), true, "IP only in /16"},
|
||||||
|
{netip.MustParseAddr("192.168.1.100"), true, "IP in both /16 and /24"},
|
||||||
|
{netip.MustParseAddr("192.168.2.100"), true, "IP only in /16"},
|
||||||
|
{netip.MustParseAddr("172.16.0.100"), false, "IP not in any prefix"},
|
||||||
|
}
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
|
for _, tc := range testCases {
|
||||||
|
_, isAllowed := manager.routeACLsPass(srcIP, tc.dstIP, fw.ProtocolTCP, 12345, 80)
|
||||||
|
require.Equal(t, tc.expected, isAllowed, tc.desc)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
|
|||||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||||
if params.Logger == nil {
|
if params.Logger == nil {
|
||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = getLogger()
|
||||||
}
|
}
|
||||||
|
|
||||||
mux := &UDPMuxDefault{
|
mux := &UDPMuxDefault{
|
||||||
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
|
|||||||
buf: make([]byte, size),
|
buf: make([]byte, size),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLogger() logging.LeveledLogger {
|
||||||
|
fac := logging.NewDefaultLoggerFactory()
|
||||||
|
//fac.Writer = log.StandardLogger().Writer()
|
||||||
|
return fac.NewLogger("ice")
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
|
|||||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||||
if params.Logger == nil {
|
if params.Logger == nil {
|
||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = getLogger()
|
||||||
}
|
}
|
||||||
if params.XORMappedAddrCacheTTL == 0 {
|
if params.XORMappedAddrCacheTTL == 0 {
|
||||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||||
|
|||||||
@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -10,16 +11,16 @@ import (
|
|||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// DropOutgoing filter outgoing packets from host to external destinations
|
// DropOutgoing filter outgoing packets from host to external destinations
|
||||||
DropOutgoing(packetData []byte) bool
|
DropOutgoing(packetData []byte, size int) bool
|
||||||
|
|
||||||
// DropIncoming filter incoming packets from external sources to host
|
// DropIncoming filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte) bool
|
DropIncoming(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||||
// Hook function receives raw network packet data as argument.
|
// Hook function receives raw network packet data as argument.
|
||||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
|
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.DropIncoming(buf[offset:]) {
|
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package mocks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
net "net"
|
||||||
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook mocks base method.
|
// AddUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(string)
|
ret0, _ := ret[0].(string)
|
||||||
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// DropIncoming mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// DropOutgoing mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
|
|||||||
@@ -18,7 +18,7 @@ func (r RuleID) ID() string {
|
|||||||
|
|
||||||
func GenerateRouteRuleKey(
|
func GenerateRouteRuleKey(
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination manager.Network,
|
||||||
proto manager.Protocol,
|
proto manager.Protocol,
|
||||||
sPort *manager.Port,
|
sPort *manager.Port,
|
||||||
dPort *manager.Port,
|
dPort *manager.Port,
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
"github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/ssh"
|
"github.com/netbirdio/netbird/client/ssh"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -25,7 +26,12 @@ var ErrSourceRangesEmpty = errors.New("sources range is empty")
|
|||||||
|
|
||||||
// Manager is a ACL rules manager
|
// Manager is a ACL rules manager
|
||||||
type Manager interface {
|
type Manager interface {
|
||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool)
|
||||||
|
}
|
||||||
|
|
||||||
|
type protoMatch struct {
|
||||||
|
ips map[string]int
|
||||||
|
policyID []byte
|
||||||
}
|
}
|
||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
@@ -48,7 +54,7 @@ func NewDefaultManager(fm firewall.Manager) *DefaultManager {
|
|||||||
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
// ApplyFiltering firewall rules to the local firewall manager processed by ACL policy.
|
||||||
//
|
//
|
||||||
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
// If allowByDefault is true it appends allow ALL traffic rules to input and output chains.
|
||||||
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap, dnsRouteFeatureFlag bool) {
|
||||||
d.mutex.Lock()
|
d.mutex.Lock()
|
||||||
defer d.mutex.Unlock()
|
defer d.mutex.Unlock()
|
||||||
|
|
||||||
@@ -77,7 +83,7 @@ func (d *DefaultManager) ApplyFiltering(networkMap *mgmProto.NetworkMap) {
|
|||||||
log.Errorf("failed to set legacy management flag: %v", err)
|
log.Errorf("failed to set legacy management flag: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules); err != nil {
|
if err := d.applyRouteACLs(networkMap.RoutesFirewallRules, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("Failed to apply route ACLs: %v", err)
|
log.Errorf("Failed to apply route ACLs: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -171,16 +177,16 @@ func (d *DefaultManager) applyPeerACLs(networkMap *mgmProto.NetworkMap) {
|
|||||||
d.peerRulesPairs = newRulePairs
|
d.peerRulesPairs = newRulePairs
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) error {
|
func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule, dynamicResolver bool) error {
|
||||||
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
newRouteRules := make(map[id.RuleID]struct{}, len(rules))
|
||||||
var merr *multierror.Error
|
var merr *multierror.Error
|
||||||
|
|
||||||
// Apply new rules - firewall manager will return existing rule ID if already present
|
// Apply new rules - firewall manager will return existing rule ID if already present
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
id, err := d.applyRouteACL(rule)
|
id, err := d.applyRouteACL(rule, dynamicResolver)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
if errors.Is(err, ErrSourceRangesEmpty) {
|
if errors.Is(err, ErrSourceRangesEmpty) {
|
||||||
log.Debugf("skipping empty rule with destination %s: %v", rule.Destination, err)
|
log.Debugf("skipping empty sources rule with destination %s: %v", rule.Destination, err)
|
||||||
} else {
|
} else {
|
||||||
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
merr = multierror.Append(merr, fmt.Errorf("add route rule: %w", err))
|
||||||
}
|
}
|
||||||
@@ -203,7 +209,7 @@ func (d *DefaultManager) applyRouteACLs(rules []*mgmProto.RouteFirewallRule) err
|
|||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.RuleID, error) {
|
func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule, dynamicResolver bool) (id.RuleID, error) {
|
||||||
if len(rule.SourceRanges) == 0 {
|
if len(rule.SourceRanges) == 0 {
|
||||||
return "", ErrSourceRangesEmpty
|
return "", ErrSourceRangesEmpty
|
||||||
}
|
}
|
||||||
@@ -217,15 +223,9 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
sources = append(sources, source)
|
sources = append(sources, source)
|
||||||
}
|
}
|
||||||
|
|
||||||
var destination netip.Prefix
|
destination, err := determineDestination(rule, dynamicResolver, sources)
|
||||||
if rule.IsDynamic {
|
if err != nil {
|
||||||
destination = getDefault(sources[0])
|
return "", fmt.Errorf("determine destination: %w", err)
|
||||||
} else {
|
|
||||||
var err error
|
|
||||||
destination, err = netip.ParsePrefix(rule.Destination)
|
|
||||||
if err != nil {
|
|
||||||
return "", fmt.Errorf("parse destination: %w", err)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
protocol, err := convertToFirewallProtocol(rule.Protocol)
|
||||||
@@ -240,7 +240,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
|
|
||||||
dPorts := convertPortInfo(rule.PortInfo)
|
dPorts := convertPortInfo(rule.PortInfo)
|
||||||
|
|
||||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("add route rule: %w", err)
|
return "", fmt.Errorf("add route rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -281,7 +281,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||||
return ruleID, rulesPair, nil
|
return ruleID, rulesPair, nil
|
||||||
}
|
}
|
||||||
@@ -289,11 +289,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
@@ -322,14 +322,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
func (d *DefaultManager) addInRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -338,18 +338,18 @@ func (d *DefaultManager) addInRules(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -364,9 +364,8 @@ func (d *DefaultManager) getPeerRuleID(
|
|||||||
direction int,
|
direction int,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
comment string,
|
|
||||||
) id.RuleID {
|
) id.RuleID {
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||||
if port != nil {
|
if port != nil {
|
||||||
idStr += port.String()
|
idStr += port.String()
|
||||||
}
|
}
|
||||||
@@ -389,10 +388,8 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type protoMatch map[mgmProto.RuleProtocol]map[string]int
|
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||||
|
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||||
in := protoMatch{}
|
|
||||||
out := protoMatch{}
|
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
// trace which type of protocols was squashed
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
squashedRules := []*mgmProto.FirewallRule{}
|
||||||
@@ -405,14 +402,18 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
// 2. Any of rule contains Port.
|
// 2. Any of rule contains Port.
|
||||||
//
|
//
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||||
if drop {
|
if drop {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = &protoMatch{
|
||||||
|
ips: map[string]int{},
|
||||||
|
// store the first encountered PolicyID for this protocol
|
||||||
|
policyID: r.PolicyID,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// special case, when we receive this all network IP address
|
// special case, when we receive this all network IP address
|
||||||
@@ -424,7 +425,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipset := protocols[r.Protocol]
|
ipset := protocols[r.Protocol].ips
|
||||||
|
|
||||||
if _, ok := ipset[r.PeerIP]; ok {
|
if _, ok := ipset[r.PeerIP]; ok {
|
||||||
return
|
return
|
||||||
@@ -450,9 +451,10 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
mgmProto.RuleProtocol_UDP,
|
mgmProto.RuleProtocol_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
|
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
||||||
for _, protocol := range protocolOrders {
|
for _, protocol := range protocolOrders {
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
match, ok := matches[protocol]
|
||||||
|
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
||||||
// don't squash if :
|
// don't squash if :
|
||||||
// 1. Rules not cover all peers in the network
|
// 1. Rules not cover all peers in the network
|
||||||
// 2. Rules cover only one peer in the network.
|
// 2. Rules cover only one peer in the network.
|
||||||
@@ -465,6 +467,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
Direction: direction,
|
Direction: direction,
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: protocol,
|
Protocol: protocol,
|
||||||
|
PolicyID: match.policyID,
|
||||||
})
|
})
|
||||||
squashedProtocols[protocol] = struct{}{}
|
squashedProtocols[protocol] = struct{}{}
|
||||||
|
|
||||||
@@ -493,9 +496,9 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
// if we also have other not squashed rules.
|
// if we also have other not squashed rules.
|
||||||
for i, r := range networkMap.FirewallRules {
|
for i, r := range networkMap.FirewallRules {
|
||||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
if _, ok := squashedProtocols[r.Protocol]; ok {
|
||||||
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
|
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||||
continue
|
continue
|
||||||
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
|
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -572,6 +575,33 @@ func convertPortInfo(portInfo *mgmProto.PortInfo) *firewall.Port {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func determineDestination(rule *mgmProto.RouteFirewallRule, dynamicResolver bool, sources []netip.Prefix) (firewall.Network, error) {
|
||||||
|
var destination firewall.Network
|
||||||
|
|
||||||
|
if rule.IsDynamic {
|
||||||
|
if dynamicResolver {
|
||||||
|
if len(rule.Domains) > 0 {
|
||||||
|
destination.Set = firewall.NewDomainSet(domain.FromPunycodeList(rule.Domains))
|
||||||
|
} else {
|
||||||
|
// isDynamic is set but no domains = outdated management server
|
||||||
|
log.Warn("connected to an older version of management server (no domains in rules), using default destination")
|
||||||
|
destination.Prefix = getDefault(sources[0])
|
||||||
|
}
|
||||||
|
} else {
|
||||||
|
// client resolves DNS, we (router) don't know the destination
|
||||||
|
destination.Prefix = getDefault(sources[0])
|
||||||
|
}
|
||||||
|
return destination, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
prefix, err := netip.ParsePrefix(rule.Destination)
|
||||||
|
if err != nil {
|
||||||
|
return destination, fmt.Errorf("parse destination: %w", err)
|
||||||
|
}
|
||||||
|
destination.Prefix = prefix
|
||||||
|
return destination, nil
|
||||||
|
}
|
||||||
|
|
||||||
func getDefault(prefix netip.Prefix) netip.Prefix {
|
func getDefault(prefix netip.Prefix) netip.Prefix {
|
||||||
if prefix.Addr().Is6() {
|
if prefix.Addr().Is6() {
|
||||||
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
return netip.PrefixFrom(netip.IPv6Unspecified(), 0)
|
||||||
|
|||||||
@@ -10,9 +10,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
@@ -52,7 +55,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@@ -63,7 +66,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
t.Run("apply firewall rules", func(t *testing.T) {
|
t.Run("apply firewall rules", func(t *testing.T) {
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if len(acl.peerRulesPairs) != 2 {
|
||||||
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
t.Errorf("firewall rules not applied: %v", acl.peerRulesPairs)
|
||||||
@@ -89,7 +92,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
// we should have one old and one new rule in the existed rules
|
// we should have one old and one new rule in the existed rules
|
||||||
if len(acl.peerRulesPairs) != 2 {
|
if len(acl.peerRulesPairs) != 2 {
|
||||||
@@ -113,13 +116,13 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
networkMap.FirewallRules = networkMap.FirewallRules[:0]
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = true
|
networkMap.FirewallRulesIsEmpty = true
|
||||||
if acl.ApplyFiltering(networkMap); len(acl.peerRulesPairs) != 0 {
|
if acl.ApplyFiltering(networkMap, false); len(acl.peerRulesPairs) != 0 {
|
||||||
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
t.Errorf("rules should be empty if FirewallRulesIsEmpty is set, got: %v", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
networkMap.FirewallRulesIsEmpty = false
|
networkMap.FirewallRulesIsEmpty = false
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
if len(acl.peerRulesPairs) != 1 {
|
if len(acl.peerRulesPairs) != 1 {
|
||||||
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
t.Errorf("rules should contain 1 rules if FirewallRulesIsEmpty is not set, got: %v", len(acl.peerRulesPairs))
|
||||||
return
|
return
|
||||||
@@ -346,7 +349,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@@ -356,7 +359,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
}(fw)
|
}(fw)
|
||||||
acl := NewDefaultManager(fw)
|
acl := NewDefaultManager(fw)
|
||||||
|
|
||||||
acl.ApplyFiltering(networkMap)
|
acl.ApplyFiltering(networkMap, false)
|
||||||
|
|
||||||
if len(acl.peerRulesPairs) != 3 {
|
if len(acl.peerRulesPairs) != 3 {
|
||||||
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
t.Errorf("expect 3 rules (last must be SSH), got: %d", len(acl.peerRulesPairs))
|
||||||
|
|||||||
@@ -94,12 +94,17 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
p.codeVerifier = codeVerifier
|
p.codeVerifier = codeVerifier
|
||||||
|
|
||||||
codeChallenge := createCodeChallenge(codeVerifier)
|
codeChallenge := createCodeChallenge(codeVerifier)
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(
|
|
||||||
state,
|
params := []oauth2.AuthCodeOption{
|
||||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
)
|
}
|
||||||
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|
||||||
return AuthFlowInfo{
|
return AuthFlowInfo{
|
||||||
VerificationURIComplete: authURL,
|
VerificationURIComplete: authURL,
|
||||||
|
|||||||
49
client/internal/auth/pkce_flow_test.go
Normal file
49
client/internal/auth/pkce_flow_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPromptLogin(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
prompt bool
|
||||||
|
}{
|
||||||
|
{"PromptLogin", true},
|
||||||
|
{"NoPromptLogin", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
config := internal.PKCEAuthProviderConfig{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Audience: "test-audience",
|
||||||
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
Scope: "openid email profile",
|
||||||
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
|
UseIDToken: true,
|
||||||
|
DisablePromptLogin: !tc.prompt,
|
||||||
|
}
|
||||||
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
|
||||||
|
}
|
||||||
|
authInfo, err := pkce.RequestAuthInfo(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
|
}
|
||||||
|
pattern := "prompt=login"
|
||||||
|
if tc.prompt {
|
||||||
|
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
||||||
|
} else {
|
||||||
|
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
|
|||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLatestNetworkMap returns the latest network map from the engine.
|
||||||
|
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||||
|
engine := c.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, errors.New("engine is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
networkMap, err := engine.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get latest network map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if networkMap == nil {
|
||||||
|
return nil, errors.New("network map is not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return networkMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current client status
|
// Status returns the current client status
|
||||||
func (c *ConnectClient) Status() StatusType {
|
func (c *ConnectClient) Status() StatusType {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
|
|||||||
1022
client/internal/debug/debug.go
Normal file
1022
client/internal/debug/debug.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,8 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android
|
||||||
|
|
||||||
package server
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/zip"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -14,36 +13,31 @@ import (
|
|||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// addFirewallRules collects and adds firewall rules to the archive
|
// addFirewallRules collects and adds firewall rules to the archive
|
||||||
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
log.Info("Collecting firewall rules")
|
log.Info("Collecting firewall rules")
|
||||||
// Collect and add iptables rules
|
|
||||||
iptablesRules, err := collectIPTablesRules()
|
iptablesRules, err := collectIPTablesRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if req.GetAnonymize() {
|
if g.anonymize {
|
||||||
iptablesRules = anonymizer.AnonymizeString(iptablesRules)
|
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||||
}
|
}
|
||||||
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect and add nftables rules
|
|
||||||
nftablesRules, err := collectNFTablesRules()
|
nftablesRules, err := collectNFTablesRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to collect nftables rules: %v", err)
|
log.Warnf("Failed to collect nftables rules: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if req.GetAnonymize() {
|
if g.anonymize {
|
||||||
nftablesRules = anonymizer.AnonymizeString(nftablesRules)
|
nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
|
||||||
}
|
}
|
||||||
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
||||||
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -65,16 +59,23 @@ func collectIPTablesRules() (string, error) {
|
|||||||
builder.WriteString("\n")
|
builder.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then get verbose statistics for each table
|
// Collect ipset information
|
||||||
|
ipsetOutput, err := collectIPSets()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to collect ipset information: %v", err)
|
||||||
|
} else {
|
||||||
|
builder.WriteString("=== ipset list output ===\n")
|
||||||
|
builder.WriteString(ipsetOutput)
|
||||||
|
builder.WriteString("\n")
|
||||||
|
}
|
||||||
|
|
||||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||||
|
|
||||||
// Get list of tables
|
|
||||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||||
|
|
||||||
// Get verbose statistics for the entire table
|
|
||||||
stats, err := getTableStatistics(table)
|
stats, err := getTableStatistics(table)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||||
@@ -87,6 +88,28 @@ func collectIPTablesRules() (string, error) {
|
|||||||
return builder.String(), nil
|
return builder.String(), nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// collectIPSets collects information about ipsets
|
||||||
|
func collectIPSets() (string, error) {
|
||||||
|
cmd := exec.Command("ipset", "list")
|
||||||
|
var stdout, stderr bytes.Buffer
|
||||||
|
cmd.Stdout = &stdout
|
||||||
|
cmd.Stderr = &stderr
|
||||||
|
|
||||||
|
if err := cmd.Run(); err != nil {
|
||||||
|
if strings.Contains(err.Error(), "executable file not found") {
|
||||||
|
return "", fmt.Errorf("ipset command not found: %w", err)
|
||||||
|
}
|
||||||
|
return "", fmt.Errorf("execute ipset list: %w (stderr: %s)", err, stderr.String())
|
||||||
|
}
|
||||||
|
|
||||||
|
ipsets := stdout.String()
|
||||||
|
if strings.TrimSpace(ipsets) == "" {
|
||||||
|
return "No ipsets found", nil
|
||||||
|
}
|
||||||
|
|
||||||
|
return ipsets, nil
|
||||||
|
}
|
||||||
|
|
||||||
// collectIPTablesSave uses iptables-save to get rule definitions
|
// collectIPTablesSave uses iptables-save to get rule definitions
|
||||||
func collectIPTablesSave() (string, error) {
|
func collectIPTablesSave() (string, error) {
|
||||||
cmd := exec.Command("iptables-save")
|
cmd := exec.Command("iptables-save")
|
||||||
@@ -182,12 +205,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format chains
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
formatChain(conn, table, chain, &builder)
|
formatChain(conn, table, chain, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format sets
|
|
||||||
if sets, err := conn.GetSets(table); err != nil {
|
if sets, err := conn.GetSets(table); err != nil {
|
||||||
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
||||||
} else if len(sets) > 0 {
|
} else if len(sets) > 0 {
|
||||||
7
client/internal/debug/debug_mobile.go
Normal file
7
client/internal/debug/debug_mobile.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
8
client/internal/debug/debug_nonlinux.go
Normal file
8
client/internal/debug/debug_nonlinux.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
// collectFirewallRules returns nothing on non-linux systems
|
||||||
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
25
client/internal/debug/debug_nonmobile.go
Normal file
25
client/internal/debug/debug_nonmobile.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
|
routes, err := systemops.GetRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: get routes including nexthop
|
||||||
|
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
|
||||||
|
routesReader := strings.NewReader(routesContent)
|
||||||
|
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add routes file to zip: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
return listOfDomains
|
return listOfDomains
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
c.removeEntry(origPattern, priority)
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if handler implements SubdomainMatcher interface
|
// Check if handler implements SubdomainMatcher interface
|
||||||
matchSubdomains := false
|
matchSubdomains := false
|
||||||
@@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
|
c.removeEntry(pattern, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
|
||||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
|
||||||
for _, entry := range c.handlers {
|
|
||||||
if strings.EqualFold(entry.Pattern, pattern) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
for _, handler := range handlers {
|
for _, handler := range handlers {
|
||||||
handler.AssertExpectations(t)
|
handler.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify handler exists check
|
|
||||||
for priority, shouldExist := range tt.expectedCalls {
|
|
||||||
if shouldExist {
|
|
||||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
|
||||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
|
// Keep track of mocks for the final assertion in Step 4
|
||||||
|
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Highest priority handler (routeHandler) should be called
|
// Highest priority handler (routeHandler) should be called
|
||||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w1, r)
|
||||||
routeHandler.AssertExpectations(t)
|
routeHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
routeHandler.ExpectedCalls = nil
|
||||||
|
routeHandler.Calls = nil
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 2: Remove highest priority handler
|
// Test 2: Remove highest priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Now middle priority handler (matchHandler) should be called
|
// Now middle priority handler (matchHandler) should be called
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w2, r)
|
||||||
matchHandler.AssertExpectations(t)
|
matchHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w3, r)
|
||||||
defaultHandler.AssertExpectations(t)
|
defaultHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
assert.False(t, chain.HasHandlers(testDomain))
|
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||||
|
|
||||||
|
for _, m := range mocks {
|
||||||
|
m.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||||
@@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addPattern string
|
||||||
|
removePattern string
|
||||||
|
queryPattern string
|
||||||
|
shouldBeRemoved bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact same pattern",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case difference",
|
||||||
|
addPattern: "Example.Com.",
|
||||||
|
removePattern: "EXAMPLE.COM.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with mixed case, removing with uppercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed case difference",
|
||||||
|
addPattern: "EXAMPLE.ORG.",
|
||||||
|
removePattern: "example.org.",
|
||||||
|
queryPattern: "example.org.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with uppercase, removing with lowercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove wildcard",
|
||||||
|
addPattern: "*.example.com.",
|
||||||
|
removePattern: "*.example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical wildcard patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove transformed pattern",
|
||||||
|
addPattern: "*.example.net.",
|
||||||
|
removePattern: "example.net.",
|
||||||
|
queryPattern: "sub.example.net.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with wildcard, removing with non-wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add transformed pattern, remove wildcard",
|
||||||
|
addPattern: "example.io.",
|
||||||
|
removePattern: "*.example.io.",
|
||||||
|
queryPattern: "example.io.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing dot difference",
|
||||||
|
addPattern: "example.dev",
|
||||||
|
removePattern: "example.dev.",
|
||||||
|
queryPattern: "example.dev.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding without trailing dot, removing with trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed trailing dot difference",
|
||||||
|
addPattern: "example.app.",
|
||||||
|
removePattern: "example.app",
|
||||||
|
queryPattern: "example.app.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with trailing dot, removing without trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case and wildcard",
|
||||||
|
addPattern: "*.Example.Site.",
|
||||||
|
removePattern: "*.EXAMPLE.SITE.",
|
||||||
|
queryPattern: "sub.example.site.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding mixed case wildcard, removing uppercase wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone",
|
||||||
|
addPattern: ".",
|
||||||
|
removePattern: ".",
|
||||||
|
queryPattern: "random.domain.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing root zone",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong domain",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "different.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding one domain, trying to remove a different domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain mismatch",
|
||||||
|
addPattern: "sub.example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding subdomain, trying to remove parent domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parent domain mismatch",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "sub.example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding parent domain, trying to remove subdomain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
handler := &nbdns.MockHandler{}
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// First verify no handler is called before adding any
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS")
|
||||||
|
|
||||||
|
// Add handler
|
||||||
|
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Verify handler is called after adding
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Reset mock for the next test
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
|
||||||
|
// Remove handler
|
||||||
|
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Set up expectations based on whether removal should succeed
|
||||||
|
if !tt.shouldBeRemoved {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if handler is still called after removal attempt
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
if tt.shouldBeRemoved {
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS",
|
||||||
|
"Handler should not be called after successful removal with pattern %q",
|
||||||
|
tt.removePattern)
|
||||||
|
} else {
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
@@ -12,8 +14,8 @@ import (
|
|||||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipv4ReverseZone = ".in-addr.arpa"
|
ipv4ReverseZone = ".in-addr.arpa."
|
||||||
ipv6ReverseZone = ".ip6.arpa"
|
ipv6ReverseZone = ".ip6.arpa."
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
@@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
|
|
||||||
for _, domain := range nsConfig.Domains {
|
for _, domain := range nsConfig.Domains {
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(domain, "."),
|
Domain: strings.ToLower(dns.Fqdn(domain)),
|
||||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: matchOnly,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.Domain)
|
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
|||||||
@@ -17,15 +17,18 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
userenv = syscall.NewLazyDLL("userenv.dll")
|
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||||
|
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||||
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||||
|
|
||||||
|
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||||
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient`
|
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\DnsPolicyConfig`
|
||||||
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match`
|
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\NetBird-Match`
|
||||||
|
|
||||||
dnsPolicyConfigVersionKey = "Version"
|
dnsPolicyConfigVersionKey = "Version"
|
||||||
dnsPolicyConfigVersionValue = 2
|
dnsPolicyConfigVersionValue = 2
|
||||||
@@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !dConf.MatchOnly {
|
if !dConf.MatchOnly {
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
@@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,10 +143,6 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
|||||||
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
|
||||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := refreshGroupPolicy(); err != nil {
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
log.Warnf("failed to refresh group policy: %v", err)
|
log.Warnf("failed to refresh group policy: %v", err)
|
||||||
}
|
}
|
||||||
@@ -188,6 +191,26 @@ func (r *registryConfigurator) string() string {
|
|||||||
return "registry"
|
return "registry"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) flushDNSCache() error {
|
||||||
|
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||||
|
if ret == 0 {
|
||||||
|
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||||
|
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("flushed DNS cache")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
@@ -240,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,6 +71,12 @@ func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
|||||||
|
|
||||||
value, found := d.records.Load(key)
|
value, found := d.records.Load(key)
|
||||||
if !found {
|
if !found {
|
||||||
|
// alternatively check if we have a cname
|
||||||
|
if question.Qtype != dns.TypeCNAME {
|
||||||
|
r.Question[0].Qtype = dns.TypeCNAME
|
||||||
|
return d.lookupRecords(r)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockServer is the mock instance of a dns server
|
// MockServer is the mock instance of a dns server
|
||||||
@@ -13,17 +14,17 @@ type MockServer struct {
|
|||||||
InitializeFunc func() error
|
InitializeFunc func() error
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||||
DeregisterHandlerFunc func([]string, int)
|
DeregisterHandlerFunc func(domain.List, int)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
if m.RegisterHandlerFunc != nil {
|
if m.RegisterHandlerFunc != nil {
|
||||||
m.RegisterHandlerFunc(domains, handler, priority)
|
m.RegisterHandlerFunc(domains, handler, priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
|
||||||
if m.DeregisterHandlerFunc != nil {
|
if m.DeregisterHandlerFunc != nil {
|
||||||
m.DeregisterHandlerFunc(domains, priority)
|
m.DeregisterHandlerFunc(domains, priority)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
searchDomains = append(searchDomains, dConf.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
@@ -18,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
@@ -32,8 +35,8 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||||
DeregisterHandler(domains []string, priority int)
|
DeregisterHandler(domains domain.List, priority int)
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@@ -65,6 +68,7 @@ type DefaultServer struct {
|
|||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
|
extraDomains map[domain.Domain]int
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
@@ -164,13 +168,15 @@ func newDefaultServer(
|
|||||||
stateManager *statemanager.Manager,
|
stateManager *statemanager.Manager,
|
||||||
disableSys bool,
|
disableSys bool,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
|
handlerChain := NewHandlerChain()
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
disableSys: disableSys,
|
disableSys: disableSys,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: handlerChain,
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
@@ -181,14 +187,26 @@ func newDefaultServer(
|
|||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// register with root zone, handler chain takes care of the routing
|
||||||
|
dnsService.RegisterMux(".", handlerChain)
|
||||||
|
|
||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||||
|
// Any previously registered handler for the same domain and priority will be replaced.
|
||||||
|
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.registerHandler(domains, handler, priority)
|
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
||||||
|
|
||||||
|
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||||
|
for _, domain := range domains {
|
||||||
|
// convert to zone with simple ref counter
|
||||||
|
s.extraDomains[toZone(domain)]++
|
||||||
|
}
|
||||||
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
@@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.handlerChain.AddHandler(domain, handler, priority)
|
s.handlerChain.AddHandler(domain, handler, priority)
|
||||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
// DeregisterHandler deregisters the handler for the given domains with the given priority.
|
||||||
|
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.deregisterHandler(domains, priority)
|
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
||||||
|
for _, domain := range domains {
|
||||||
|
zone := toZone(domain)
|
||||||
|
s.extraDomains[zone]--
|
||||||
|
if s.extraDomains[zone] <= 0 {
|
||||||
|
delete(s.extraDomains, zone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
@@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.handlerChain.RemoveHandler(domain, priority)
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
// Only deregister from service if no handlers remain
|
|
||||||
if !s.handlerChain.HasHandlers(domain) {
|
|
||||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
|
|
||||||
|
maps.Clear(s.extraDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
@@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
// is the service should be Disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if update.ServiceEnable {
|
if update.ServiceEnable {
|
||||||
_ = s.service.Listen()
|
if err := s.service.Listen(); err != nil {
|
||||||
|
log.Errorf("failed to start DNS service: %v", err)
|
||||||
|
}
|
||||||
} else if !s.permanent {
|
} else if !s.permanent {
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
}
|
}
|
||||||
@@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
|
||||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||||
hostUpdate.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
s.applyHostConfig()
|
||||||
log.Error(err)
|
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// persist dns state right away
|
// persist dns state right away
|
||||||
@@ -441,6 +462,43 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyHostConfig() {
|
||||||
|
if s.hostManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// prevent reapplying config if we're shutting down
|
||||||
|
if s.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
config := s.currentConfig
|
||||||
|
|
||||||
|
existingDomains := make(map[string]struct{})
|
||||||
|
for _, d := range config.Domains {
|
||||||
|
existingDomains[d.Domain] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add extra domains only if they're not already in the config
|
||||||
|
for domain := range s.extraDomains {
|
||||||
|
domainStr := domain.PunycodeString()
|
||||||
|
|
||||||
|
if _, exists := existingDomains[domainStr]; !exists {
|
||||||
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
|
Domain: domainStr,
|
||||||
|
MatchOnly: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||||
|
|
||||||
|
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||||
|
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||||
|
s.handleErrNoGroupaAll(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||||
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
||||||
return
|
return
|
||||||
@@ -690,10 +748,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
s.applyHostConfig()
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
@@ -728,12 +783,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.hostManager != nil {
|
s.applyHostConfig()
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
}
|
}
|
||||||
@@ -836,3 +886,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toZone(d domain.Domain) domain.Domain {
|
||||||
|
return domain.Domain(
|
||||||
|
nbdns.NormalizeZone(
|
||||||
|
dns.Fqdn(
|
||||||
|
strings.ToLower(d.PunycodeString()),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,19 +23,23 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type mocWGIface struct {
|
type mocWGIface struct {
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Name() string {
|
func (w *mocWGIface) Name() string {
|
||||||
panic("implement me")
|
return "utun2301"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Address() wgaddr.Address {
|
func (w *mocWGIface) Address() wgaddr.Address {
|
||||||
@@ -456,7 +460,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
@@ -917,7 +921,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface, false)
|
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1445,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtraDomains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialConfig nbdns.Config
|
||||||
|
registerDomains []domain.List
|
||||||
|
deregisterDomains []domain.List
|
||||||
|
finalConfig nbdns.Config
|
||||||
|
expectedDomains []string
|
||||||
|
expectedMatchOnly []string
|
||||||
|
applyHostConfigCall int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Register domains before config update",
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
},
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domains after config update",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register overlapping domains",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "overlap.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "overlap.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"overlap.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register and deregister domains",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
{"extra3.example.com", "extra4.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra3.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
"extra4.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra2.example.com.",
|
||||||
|
"extra4.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domains with ref counter",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "duplicate.example.com"},
|
||||||
|
{"other.example.com", "duplicate.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"duplicate.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
"other.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"other.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Config update with new domains after registration",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "duplicate.example.com"},
|
||||||
|
},
|
||||||
|
finalConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "newconfig.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"newconfig.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deregister domain that is part of customZones",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "protected.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "protected.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"protected.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"config.example.com.",
|
||||||
|
"protected.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domain that is part of nameserver group",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "overlap.ns.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"ns.example.com.",
|
||||||
|
"overlap.ns.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"ns.example.com.",
|
||||||
|
"overlap.ns.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var capturedConfigs []HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfigs = append(capturedConfigs, config)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
wgInterface: &mocWGIface{},
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply initial configuration
|
||||||
|
if tt.initialConfig.ServiceEnable {
|
||||||
|
err := server.applyConfiguration(tt.initialConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register domains
|
||||||
|
for _, domains := range tt.registerDomains {
|
||||||
|
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister domains if specified
|
||||||
|
for _, domains := range tt.deregisterDomains {
|
||||||
|
server.DeregisterHandler(domains, PriorityDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply final configuration if specified
|
||||||
|
if tt.finalConfig.ServiceEnable {
|
||||||
|
err := server.applyConfiguration(tt.finalConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify number of calls
|
||||||
|
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
|
||||||
|
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
|
||||||
|
|
||||||
|
// Get the last applied config
|
||||||
|
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||||
|
|
||||||
|
// Check all expected domains are present
|
||||||
|
domainMap := make(map[string]bool)
|
||||||
|
matchOnlyMap := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, d := range lastConfig.Domains {
|
||||||
|
domainMap[d.Domain] = true
|
||||||
|
if d.MatchOnly {
|
||||||
|
matchOnlyMap[d.Domain] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected domains
|
||||||
|
for _, d := range tt.expectedDomains {
|
||||||
|
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify match-only domains
|
||||||
|
for _, d := range tt.expectedMatchOnly {
|
||||||
|
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no unexpected domains
|
||||||
|
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
|
||||||
|
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtraDomainsRefCounting(t *testing.T) {
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register domains from different handlers with same domain
|
||||||
|
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||||
|
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
// Verify refcount is 2
|
||||||
|
zoneKey := toZone("shared.example.com")
|
||||||
|
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||||
|
|
||||||
|
// Deregister one handler
|
||||||
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
// Verify refcount is 1
|
||||||
|
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||||
|
|
||||||
|
// Deregister the other handler
|
||||||
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
|
||||||
|
|
||||||
|
// Verify domain is removed
|
||||||
|
_, exists := server.extraDomains[zoneKey]
|
||||||
|
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||||
|
var capturedConfig HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfig = config
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
|
|
||||||
|
initialConfig := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := server.applyConfiguration(initialConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
}
|
||||||
|
assert.Contains(t, domains, "config.example.com.")
|
||||||
|
assert.Contains(t, domains, "extra.example.com.")
|
||||||
|
|
||||||
|
// Now apply a new configuration with overlapping domain
|
||||||
|
updatedConfig := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "extra.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = server.applyConfiguration(updatedConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify both domains are in config, but no duplicates
|
||||||
|
domains = []string{}
|
||||||
|
matchOnlyCount := 0
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
if d.MatchOnly {
|
||||||
|
matchOnlyCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, domains, "config.example.com.")
|
||||||
|
assert.Contains(t, domains, "extra.example.com.")
|
||||||
|
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||||
|
|
||||||
|
// Extra domain should no longer be marked as match-only when in config
|
||||||
|
matchOnlyDomain := ""
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||||
|
matchOnlyDomain = d.Domain
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainCaseHandling(t *testing.T) {
|
||||||
|
var capturedConfig HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfig = config
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
|
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||||
|
|
||||||
|
config := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := server.applyConfiguration(config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
}
|
||||||
|
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||||
|
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
@@ -38,7 +37,6 @@ const (
|
|||||||
|
|
||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
|
||||||
ifaceName string
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: dns.Fqdn(dConf.Domain),
|
Domain: dConf.Domain,
|
||||||
MatchOnly: dConf.MatchOnly,
|
MatchOnly: dConf.MatchOnly,
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -124,18 +122,19 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
|
||||||
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
|
return fmt.Errorf("set link as default dns router: %w", err)
|
||||||
}
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: nbdns.RootZone,
|
Domain: nbdns.RootZone,
|
||||||
MatchOnly: true,
|
MatchOnly: true,
|
||||||
})
|
})
|
||||||
s.routingAll = true
|
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||||
} else if s.routingAll {
|
} else {
|
||||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
|
||||||
|
return fmt.Errorf("remove link as default dns router: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
state := &ShutdownState{
|
state := &ShutdownState{
|
||||||
@@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||||
}
|
}
|
||||||
return s.flushCaches()
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||||
@@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.flushCaches()
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
func (s *systemdDbusConfigurator) flushDNSCache() error {
|
||||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||||
|
|||||||
@@ -18,14 +18,16 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
UpstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
failsTillDeact = int32(5)
|
failsTillDeact = int32(5)
|
||||||
reactivatePeriod = 30 * time.Second
|
reactivatePeriod = 30 * time.Second
|
||||||
upstreamTimeout = 15 * time.Second
|
|
||||||
probeTimeout = 2 * time.Second
|
probeTimeout = 2 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,7 +68,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
domain: domain,
|
domain: domain,
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
@@ -106,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.SetEdns0(4096, false)
|
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
|
|||||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||||
|
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||||
|
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||||
|
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||||
|
// MTU - ip + udp headers
|
||||||
|
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
|
||||||
|
client.UDPSize = iface.DefaultMTU - (60 + 8)
|
||||||
|
|
||||||
|
var (
|
||||||
|
rm *dns.Msg
|
||||||
|
t time.Duration
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
rm, t, err = client.Exchange(r, upstream)
|
||||||
|
} else {
|
||||||
|
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rm == nil || !rm.MsgHdr.Truncated {
|
||||||
|
return rm, t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||||
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
|
client.Net = "tcp"
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
rm, t, err = client.Exchange(r, upstream)
|
||||||
|
} else {
|
||||||
|
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||||
|
|
||||||
|
return rm, t, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
|||||||
|
|
||||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||||
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
timeout := upstreamTimeout
|
timeout := UpstreamTimeout
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
timeout = time.Until(deadline)
|
timeout = time.Until(deadline)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,5 @@ func newUpstreamResolver(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
upstreamExchangeClient := &dns.Client{}
|
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := upstreamTimeout
|
timeout := UpstreamTimeout
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
timeout = time.Until(deadline)
|
timeout = time.Until(deadline)
|
||||||
}
|
}
|
||||||
@@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||||
return client.Exchange(r, upstream)
|
return ExchangeWithFallback(nil, client, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user