mirror of
https://github.com/netbirdio/netbird.git
synced 2026-04-19 16:56:39 +00:00
Compare commits
90 Commits
test/netwo
...
feature/up
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
06c5564dd2 | ||
|
|
cace564a19 | ||
|
|
d07c83c111 | ||
|
|
b0a5696bec | ||
|
|
dc24b9e276 | ||
|
|
312bfd9bd7 | ||
|
|
8db05838ca | ||
|
|
c69df13515 | ||
|
|
986eb8c1e0 | ||
|
|
197761ba4d | ||
|
|
f74ea64c7b | ||
|
|
3b7b9d25bc | ||
|
|
1a6d6b3109 | ||
|
|
f686615876 | ||
|
|
a4311f574d | ||
|
|
0bb8eae903 | ||
|
|
e0b33d325d | ||
|
|
c38e07d89a | ||
|
|
a37368fff4 | ||
|
|
0c93bd3d06 | ||
|
|
a675531b5c | ||
|
|
7cb366bc7d | ||
|
|
a354004564 | ||
|
|
75bdd47dfb | ||
|
|
b165f63327 | ||
|
|
51bb52cdf5 | ||
|
|
4134b857b4 | ||
|
|
7839d2c169 | ||
|
|
b9f82e2f8a | ||
|
|
fd2a21c65d | ||
|
|
82d982b0ab | ||
|
|
9e24fe7701 | ||
|
|
e470701b80 | ||
|
|
e3ce026355 | ||
|
|
5ea2806663 | ||
|
|
d6b0673580 | ||
|
|
14913cfa7a | ||
|
|
03f600b576 | ||
|
|
192c97aa63 | ||
|
|
4db78db49a | ||
|
|
87e600a4f3 | ||
|
|
6162aeb82d | ||
|
|
1ba1e092ce | ||
|
|
86dbb4ee4f | ||
|
|
4af177215f | ||
|
|
df9c1b9883 | ||
|
|
5752bb78f2 | ||
|
|
fbd783ad58 | ||
|
|
80702b9323 | ||
|
|
09243a0fe0 | ||
|
|
3658215747 | ||
|
|
48ffec95dd | ||
|
|
cbec7bda80 | ||
|
|
21464ac770 | ||
|
|
ed5647028a | ||
|
|
29a6e5be71 | ||
|
|
6124e3b937 | ||
|
|
50f5cc48cd | ||
|
|
101cce27f2 | ||
|
|
a4f04f5570 | ||
|
|
fceb3ca392 | ||
|
|
34d86c5ab8 | ||
|
|
9cbcf7531f | ||
|
|
bd8f0c1ef3 | ||
|
|
051a5a4adc | ||
|
|
8b4c0c58e4 | ||
|
|
99b41543b8 | ||
|
|
2bbe0f3f09 | ||
|
|
9325fb7990 | ||
|
|
f081435a56 | ||
|
|
b62a1b56ce | ||
|
|
8d7c92c661 | ||
|
|
d9d051cb1e | ||
|
|
cb318b7ef4 | ||
|
|
8f0aa8352a | ||
|
|
c02e236196 | ||
|
|
f51e0b59bd | ||
|
|
32ec42a667 | ||
|
|
9929daf6ce | ||
|
|
939419a0ea | ||
|
|
919fe94fd5 | ||
|
|
df71cb4690 | ||
|
|
4508c61728 | ||
|
|
0ef476b014 | ||
|
|
6f82e96d6a | ||
|
|
a2faae5d62 | ||
|
|
4a3cbcd38a | ||
|
|
c2980bc8cf | ||
|
|
67ae871ce4 | ||
|
|
39ff5e833a |
27
.git-branches.toml
Normal file
27
.git-branches.toml
Normal file
@@ -0,0 +1,27 @@
|
|||||||
|
# More info around this file at https://www.git-town.com/configuration-file
|
||||||
|
|
||||||
|
[branches]
|
||||||
|
main = "main"
|
||||||
|
perennials = []
|
||||||
|
perennial-regex = ""
|
||||||
|
|
||||||
|
[create]
|
||||||
|
new-branch-type = "feature"
|
||||||
|
push-new-branches = false
|
||||||
|
|
||||||
|
[hosting]
|
||||||
|
dev-remote = "origin"
|
||||||
|
# platform = ""
|
||||||
|
# origin-hostname = ""
|
||||||
|
|
||||||
|
[ship]
|
||||||
|
delete-tracking-branch = false
|
||||||
|
strategy = "squash-merge"
|
||||||
|
|
||||||
|
[sync]
|
||||||
|
feature-strategy = "merge"
|
||||||
|
perennial-strategy = "rebase"
|
||||||
|
prototype-strategy = "merge"
|
||||||
|
push-hook = true
|
||||||
|
tags = true
|
||||||
|
upstream = false
|
||||||
4
.github/pull_request_template.md
vendored
4
.github/pull_request_template.md
vendored
@@ -2,6 +2,10 @@
|
|||||||
|
|
||||||
## Issue ticket number and link
|
## Issue ticket number and link
|
||||||
|
|
||||||
|
## Stack
|
||||||
|
|
||||||
|
<!-- branch-stack -->
|
||||||
|
|
||||||
### Checklist
|
### Checklist
|
||||||
- [ ] Is it a bug fix
|
- [ ] Is it a bug fix
|
||||||
- [ ] Is a typo/documentation fix
|
- [ ] Is a typo/documentation fix
|
||||||
|
|||||||
21
.github/workflows/git-town.yml
vendored
Normal file
21
.github/workflows/git-town.yml
vendored
Normal file
@@ -0,0 +1,21 @@
|
|||||||
|
name: Git Town
|
||||||
|
|
||||||
|
on:
|
||||||
|
pull_request:
|
||||||
|
branches:
|
||||||
|
- '**'
|
||||||
|
|
||||||
|
jobs:
|
||||||
|
git-town:
|
||||||
|
name: Display the branch stack
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
|
||||||
|
permissions:
|
||||||
|
contents: read
|
||||||
|
pull-requests: write
|
||||||
|
|
||||||
|
steps:
|
||||||
|
- uses: actions/checkout@v4
|
||||||
|
- uses: git-town/action@v1
|
||||||
|
with:
|
||||||
|
skip-single-stacks: true
|
||||||
10
.github/workflows/golang-test-freebsd.yml
vendored
10
.github/workflows/golang-test-freebsd.yml
vendored
@@ -22,14 +22,20 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
usesh: true
|
usesh: true
|
||||||
copyback: false
|
copyback: false
|
||||||
release: "14.1"
|
release: "14.2"
|
||||||
prepare: |
|
prepare: |
|
||||||
pkg install -y go pkgconf xorg
|
pkg install -y curl pkgconf xorg
|
||||||
|
LATEST_VERSION=$(curl -s https://go.dev/VERSION?m=text|head -n 1)
|
||||||
|
GO_TARBALL="$LATEST_VERSION.freebsd-amd64.tar.gz"
|
||||||
|
GO_URL="https://go.dev/dl/$GO_TARBALL"
|
||||||
|
curl -vLO "$GO_URL"
|
||||||
|
tar -C /usr/local -vxzf "$GO_TARBALL"
|
||||||
|
|
||||||
# -x - to print all executed commands
|
# -x - to print all executed commands
|
||||||
# -e - to faile on first error
|
# -e - to faile on first error
|
||||||
run: |
|
run: |
|
||||||
set -e -x
|
set -e -x
|
||||||
|
export PATH=$PATH:/usr/local/go/bin:$HOME/go/bin
|
||||||
time go build -o netbird client/main.go
|
time go build -o netbird client/main.go
|
||||||
# check all component except management, since we do not support management server on freebsd
|
# check all component except management, since we do not support management server on freebsd
|
||||||
time go test -timeout 1m -failfast ./base62/...
|
time go test -timeout 1m -failfast ./base62/...
|
||||||
|
|||||||
78
.github/workflows/golang-test-linux.yml
vendored
78
.github/workflows/golang-test-linux.yml
vendored
@@ -314,6 +314,7 @@ jobs:
|
|||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=devcert \
|
go test -tags=devcert \
|
||||||
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
-exec "sudo --preserve-env=CI,NETBIRD_STORE_ENGINE" \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
@@ -380,7 +381,8 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags devcert -run=^$ -bench=. \
|
go test -tags devcert -run=^$ -bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./...
|
-timeout 20m ./...
|
||||||
@@ -396,6 +398,33 @@ jobs:
|
|||||||
store: [ 'sqlite', 'postgres' ]
|
store: [ 'sqlite', 'postgres' ]
|
||||||
runs-on: ubuntu-22.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
|
- name: Create Docker network
|
||||||
|
run: docker network create promnet
|
||||||
|
|
||||||
|
- name: Start Prometheus Pushgateway
|
||||||
|
run: docker run -d --name pushgateway --network promnet -p 9091:9091 prom/pushgateway
|
||||||
|
|
||||||
|
- name: Start Prometheus (for Pushgateway forwarding)
|
||||||
|
run: |
|
||||||
|
echo '
|
||||||
|
global:
|
||||||
|
scrape_interval: 15s
|
||||||
|
scrape_configs:
|
||||||
|
- job_name: "pushgateway"
|
||||||
|
static_configs:
|
||||||
|
- targets: ["pushgateway:9091"]
|
||||||
|
remote_write:
|
||||||
|
- url: ${{ secrets.GRAFANA_URL }}
|
||||||
|
basic_auth:
|
||||||
|
username: ${{ secrets.GRAFANA_USER }}
|
||||||
|
password: ${{ secrets.GRAFANA_API_KEY }}
|
||||||
|
' > prometheus.yml
|
||||||
|
|
||||||
|
docker run -d --name prometheus --network promnet \
|
||||||
|
-v $PWD/prometheus.yml:/etc/prometheus/prometheus.yml \
|
||||||
|
-p 9090:9090 \
|
||||||
|
prom/prometheus
|
||||||
|
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
with:
|
with:
|
||||||
@@ -447,11 +476,13 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
|
GIT_BRANCH=${{ github.ref_name }} \
|
||||||
go test -tags=benchmark \
|
go test -tags=benchmark \
|
||||||
-run=^$ \
|
-run=^$ \
|
||||||
-bench=. \
|
-bench=. \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE,GIT_BRANCH,GITHUB_RUN_ID' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
|
|
||||||
api_integration_test:
|
api_integration_test:
|
||||||
@@ -505,7 +536,8 @@ jobs:
|
|||||||
- name: Test
|
- name: Test
|
||||||
run: |
|
run: |
|
||||||
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
CGO_ENABLED=1 GOARCH=${{ matrix.arch }} \
|
||||||
NETBIRD_STORE_ENGINE=${{ matrix.store }} CI=true \
|
NETBIRD_STORE_ENGINE=${{ matrix.store }} \
|
||||||
|
CI=true \
|
||||||
go test -tags=integration \
|
go test -tags=integration \
|
||||||
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
-exec 'sudo --preserve-env=CI,NETBIRD_STORE_ENGINE' \
|
||||||
-timeout 20m ./management/...
|
-timeout 20m ./management/...
|
||||||
@@ -513,7 +545,7 @@ jobs:
|
|||||||
test_client_on_docker:
|
test_client_on_docker:
|
||||||
name: "Client (Docker) / Unit"
|
name: "Client (Docker) / Unit"
|
||||||
needs: [ build-cache ]
|
needs: [ build-cache ]
|
||||||
runs-on: ubuntu-20.04
|
runs-on: ubuntu-22.04
|
||||||
steps:
|
steps:
|
||||||
- name: Install Go
|
- name: Install Go
|
||||||
uses: actions/setup-go@v5
|
uses: actions/setup-go@v5
|
||||||
@@ -527,7 +559,7 @@ jobs:
|
|||||||
- name: Get Go environment
|
- name: Get Go environment
|
||||||
run: |
|
run: |
|
||||||
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
echo "cache=$(go env GOCACHE)" >> $GITHUB_ENV
|
||||||
echo "modcache=$(go env GOMODCACHE)" >> $GITHUB_ENV
|
echo "modcache=$(go env.GOMODCACHE)" >> $GITHUB_ENV
|
||||||
|
|
||||||
- name: Cache Go modules
|
- name: Cache Go modules
|
||||||
uses: actions/cache/restore@v4
|
uses: actions/cache/restore@v4
|
||||||
@@ -545,7 +577,7 @@ jobs:
|
|||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|
||||||
- name: check git status
|
- name: Check git status
|
||||||
run: git --no-pager diff --exit-code
|
run: git --no-pager diff --exit-code
|
||||||
|
|
||||||
- name: Generate Shared Sock Test bin
|
- name: Generate Shared Sock Test bin
|
||||||
@@ -554,8 +586,22 @@ jobs:
|
|||||||
- name: Generate RouteManager Test bin
|
- name: Generate RouteManager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
run: CGO_ENABLED=0 go test -c -o routemanager-testing.bin ./client/internal/routemanager
|
||||||
|
|
||||||
- name: Generate SystemOps Test bin
|
- name: Generate SystemOps Test bin (static via Alpine)
|
||||||
run: CGO_ENABLED=1 go test -c -o systemops-testing.bin -tags netgo -ldflags '-w -extldflags "-static -ldbus-1 -lpcap"' ./client/internal/routemanager/systemops
|
run: |
|
||||||
|
docker run --rm -v $PWD:/app -w /app \
|
||||||
|
alpine:latest \
|
||||||
|
sh -c "
|
||||||
|
apk add --no-cache go gcc musl-dev libpcap-dev dbus-dev && \
|
||||||
|
adduser -D -u $(id -u) builder && \
|
||||||
|
su builder -c '\
|
||||||
|
cd /app && \
|
||||||
|
CGO_ENABLED=1 GOOS=linux GOARCH=amd64 \
|
||||||
|
go test -c -o /app/systemops-testing.bin \
|
||||||
|
-tags netgo \
|
||||||
|
-ldflags=\"-w -extldflags \\\"-static -ldbus-1 -lpcap\\\"\" \
|
||||||
|
./client/internal/routemanager/systemops \
|
||||||
|
'
|
||||||
|
"
|
||||||
|
|
||||||
- name: Generate nftables Manager Test bin
|
- name: Generate nftables Manager Test bin
|
||||||
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
run: CGO_ENABLED=0 go test -c -o nftablesmanager-testing.bin ./client/firewall/nftables/...
|
||||||
@@ -569,25 +615,25 @@ jobs:
|
|||||||
- run: chmod +x *testing.bin
|
- run: chmod +x *testing.bin
|
||||||
|
|
||||||
- name: Run Shared Sock tests in docker
|
- name: Run Shared Sock tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/sharedsock-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/sharedsock --entrypoint /ci/sharedsock-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Iface tests in docker
|
- name: Run Iface tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/netbird -v /tmp/cache:/tmp/cache -v /tmp/modcache:/tmp/modcache -w /netbird -e GOCACHE=/tmp/cache -e GOMODCACHE=/tmp/modcache -e CGO_ENABLED=0 golang:1.23-alpine go test -test.timeout 5m -test.parallel 1 ./client/iface/...
|
||||||
|
|
||||||
- name: Run RouteManager tests in docker
|
- name: Run RouteManager tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/routemanager-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager --entrypoint /ci/routemanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run SystemOps tests in docker
|
- name: Run SystemOps tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/systemops-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/routemanager/systemops --entrypoint /ci/systemops-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run nftables Manager tests in docker
|
- name: Run nftables Manager tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/nftablesmanager-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/firewall --entrypoint /ci/nftablesmanager-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Engine tests in docker with file store
|
- name: Run Engine tests in docker with file store
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="jsonfile" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Engine tests in docker with sqlite store
|
- name: Run Engine tests in docker with sqlite store
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/engine-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal -e NETBIRD_STORE_ENGINE="sqlite" --entrypoint /ci/engine-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1
|
||||||
|
|
||||||
- name: Run Peer tests in docker
|
- name: Run Peer tests in docker
|
||||||
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /busybox/sh gcr.io/distroless/base:debug -c /ci/peer-testing.bin -test.timeout 5m -test.parallel 1
|
run: docker run -t --cap-add=NET_ADMIN --privileged --rm -v $PWD:/ci -w /ci/client/internal/peer --entrypoint /ci/peer-testing.bin gcr.io/distroless/base:debug -test.timeout 5m -test.parallel 1
|
||||||
|
|||||||
2
.github/workflows/golangci-lint.yml
vendored
2
.github/workflows/golangci-lint.yml
vendored
@@ -19,7 +19,7 @@ jobs:
|
|||||||
- name: codespell
|
- name: codespell
|
||||||
uses: codespell-project/actions-codespell@v2
|
uses: codespell-project/actions-codespell@v2
|
||||||
with:
|
with:
|
||||||
ignore_words_list: erro,clienta,hastable,iif,groupd,testin
|
ignore_words_list: erro,clienta,hastable,iif,groupd,testin,groupe
|
||||||
skip: go.mod,go.sum
|
skip: go.mod,go.sum
|
||||||
only_warn: 1
|
only_warn: 1
|
||||||
golangci:
|
golangci:
|
||||||
|
|||||||
8
.github/workflows/release.yml
vendored
8
.github/workflows/release.yml
vendored
@@ -87,25 +87,25 @@ jobs:
|
|||||||
with:
|
with:
|
||||||
name: release
|
name: release
|
||||||
path: dist/
|
path: dist/
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
- name: upload linux packages
|
- name: upload linux packages
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: linux-packages
|
name: linux-packages
|
||||||
path: dist/netbird_linux**
|
path: dist/netbird_linux**
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
- name: upload windows packages
|
- name: upload windows packages
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: windows-packages
|
name: windows-packages
|
||||||
path: dist/netbird_windows**
|
path: dist/netbird_windows**
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
- name: upload macos packages
|
- name: upload macos packages
|
||||||
uses: actions/upload-artifact@v4
|
uses: actions/upload-artifact@v4
|
||||||
with:
|
with:
|
||||||
name: macos-packages
|
name: macos-packages
|
||||||
path: dist/netbird_darwin**
|
path: dist/netbird_darwin**
|
||||||
retention-days: 3
|
retention-days: 7
|
||||||
|
|
||||||
release_ui:
|
release_ui:
|
||||||
runs-on: ubuntu-latest
|
runs-on: ubuntu-latest
|
||||||
|
|||||||
@@ -178,6 +178,7 @@ jobs:
|
|||||||
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
grep -A 10 'relay:' docker-compose.yml | egrep 'NB_AUTH_SECRET=.+$'
|
||||||
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
grep -A 7 Relay management.json | grep "rel://$CI_NETBIRD_DOMAIN:33445"
|
||||||
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
grep -A 7 Relay management.json | egrep '"Secret": ".+"'
|
||||||
|
grep DisablePromptLogin management.json | grep 'true'
|
||||||
|
|
||||||
- name: Install modules
|
- name: Install modules
|
||||||
run: go mod tidy
|
run: go mod tidy
|
||||||
|
|||||||
@@ -1,148 +1,64 @@
|
|||||||
# Contributor License Agreement
|
## Contributor License Agreement
|
||||||
|
|
||||||
We are incredibly thankful for the contributions we receive from the community.
|
This Contributor License Agreement (referred to as the "Agreement") is entered into by the individual
|
||||||
We require our external contributors to sign a Contributor License Agreement ("CLA") in
|
submitting this Agreement and NetBird GmbH, c/o Max-Beer-Straße 2-4 Münzstraße 12 10178 Berlin, Germany,
|
||||||
order to ensure that our projects remain licensed under Free and Open Source licenses such
|
referred to as "NetBird" (collectively, the "Parties"). The Agreement outlines the terms and conditions
|
||||||
as BSD-3 while allowing NetBird to build a sustainable business.
|
under which NetBird may utilize software contributions provided by the Contributor for inclusion in
|
||||||
|
its software development projects. By submitting this Agreement, the Contributor confirms their acceptance
|
||||||
NetBird is committed to having a true Open Source Software ("OSS") license for
|
of the terms and conditions outlined below. The Contributor further represents that they are authorized to
|
||||||
our software. A CLA enables NetBird to safely commercialize our products
|
complete this process as described herein.
|
||||||
while keeping a standard OSS license with all the rights that license grants to users: the
|
|
||||||
ability to use the project in their own projects or businesses, to republish modified
|
|
||||||
source, or to completely fork the project.
|
|
||||||
|
|
||||||
This page gives a human-friendly summary of our CLA, details on why we require a CLA, how
|
|
||||||
contributors can sign our CLA, and more. You may view the full legal CLA document (below).
|
|
||||||
|
|
||||||
# Human-friendly summary
|
|
||||||
|
|
||||||
This is a human-readable summary of (and not a substitute for) the full agreement (below).
|
|
||||||
This highlights only some of key terms of the CLA. It has no legal value and you should
|
|
||||||
carefully review all the terms of the actual CLA before agreeing.
|
|
||||||
|
|
||||||
<li>Grant of copyright license. You give NetBird permission to use your copyrighted work
|
|
||||||
in commercial products.
|
|
||||||
</li>
|
|
||||||
|
|
||||||
<li>Grant of patent license. If your contributed work uses a patent, you give NetBird a
|
|
||||||
license to use that patent including within commercial products. You also agree that you
|
|
||||||
have permission to grant this license.
|
|
||||||
</li>
|
|
||||||
|
|
||||||
<li>No Warranty or Support Obligations.
|
|
||||||
By making a contribution, you are not obligating yourself to provide support for the
|
|
||||||
contribution, and you are not taking on any warranty obligations or providing any
|
|
||||||
assurances about how it will perform.
|
|
||||||
</li>
|
|
||||||
|
|
||||||
The CLA does not change the terms of the standard open source license used by our software
|
|
||||||
such as BSD-3 or MIT.
|
|
||||||
You are still free to use our projects within your own projects or businesses, republish
|
|
||||||
modified source, and more.
|
|
||||||
Please reference the appropriate license for the project you're contributing to to learn
|
|
||||||
more.
|
|
||||||
|
|
||||||
# Why require a CLA?
|
|
||||||
|
|
||||||
Agreeing to a CLA explicitly states that you are entitled to provide a contribution, that you cannot withdraw permission
|
|
||||||
to use your contribution at a later date, and that NetBird has permission to use your contribution in our commercial
|
|
||||||
products.
|
|
||||||
|
|
||||||
This removes any ambiguities or uncertainties caused by not having a CLA and allows users and customers to confidently
|
|
||||||
adopt our projects. At the same time, the CLA ensures that all contributions to our open source projects are licensed
|
|
||||||
under the project's respective open source license, such as BSD-3.
|
|
||||||
|
|
||||||
Requiring a CLA is a common and well-accepted practice in open source. Major open source projects require CLAs such as
|
|
||||||
Apache Software Foundation projects, Facebook projects (such as React), Google projects (including Go), Python, Django,
|
|
||||||
and more. Each of these projects remains licensed under permissive OSS licenses such as MIT, Apache, BSD, and more.
|
|
||||||
|
|
||||||
# Signing the CLA
|
|
||||||
|
|
||||||
Open a pull request ("PR") to any of our open source projects to sign the CLA. A bot will comment on the PR asking you
|
|
||||||
to sign the CLA if you haven't already.
|
|
||||||
|
|
||||||
Follow the steps given by the bot to sign the CLA. This will require you to log in with GitHub (we only request public
|
|
||||||
information from your account) and to fill in a few additional details such as your name and email address. We will only
|
|
||||||
use this information for CLA tracking; none of your submitted information will be used for marketing purposes.
|
|
||||||
|
|
||||||
You only have to sign the CLA once. Once you've signed the CLA, future contributions to any NetBird project will not
|
|
||||||
require you to sign again.
|
|
||||||
|
|
||||||
# Legal Terms and Agreement
|
|
||||||
|
|
||||||
In order to clarify the intellectual property license granted with Contributions from any person or entity, NetBird
|
|
||||||
GmbH ("NetBird") must have a Contributor License Agreement ("CLA") on file that has been signed
|
|
||||||
by each Contributor, indicating agreement to the license terms below. This license does not change your rights to use
|
|
||||||
your own Contributions for any other purpose.
|
|
||||||
|
|
||||||
You accept and agree to the following terms and conditions for Your present and future Contributions submitted to
|
|
||||||
NetBird. Except for the license granted herein to NetBird and recipients of software distributed by NetBird,
|
|
||||||
You reserve all right, title, and interest in and to Your Contributions.
|
|
||||||
|
|
||||||
1. Definitions.
|
|
||||||
|
|
||||||
```
|
|
||||||
"You" (or "Your") shall mean the copyright owner or legal entity authorized by the copyright owner
|
|
||||||
that is making this Agreement with NetBird. For legal entities, the entity making a Contribution and all other
|
|
||||||
entities that control, are controlled by, or are under common control with that entity are considered
|
|
||||||
to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect,
|
|
||||||
to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty
|
|
||||||
percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
|
||||||
```
|
|
||||||
```
|
|
||||||
"Contribution" shall mean any original work of authorship, including any modifications or additions to
|
|
||||||
an existing work, that is or previously has been intentionally submitted by You to NetBird for inclusion in,
|
|
||||||
or documentation of, any of the products owned or managed by NetBird (the "Work").
|
|
||||||
For the purposes of this definition, "submitted" means any form of electronic, verbal, or written communication
|
|
||||||
sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists,
|
|
||||||
source code control systems, and issue tracking systems that are managed by, or on behalf of,
|
|
||||||
NetBird for the purpose of discussing and improving the Work, but excluding communication that is conspicuously
|
|
||||||
marked or otherwise designated in writing by You as "Not a Contribution."
|
|
||||||
```
|
|
||||||
|
|
||||||
2. Grant of Copyright License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird
|
|
||||||
and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge,
|
|
||||||
royalty-free, irrevocable copyright license to reproduce, prepare derivative works of, publicly display, publicly
|
|
||||||
perform, sublicense, and distribute Your Contributions and such derivative works.
|
|
||||||
|
|
||||||
|
|
||||||
3. Grant of Patent License. Subject to the terms and conditions of this Agreement, You hereby grant to NetBird and
|
## 1 Preamble
|
||||||
to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free,
|
In order to clarify the IP Rights situation with regard to Contributions from any person or entity, NetBird
|
||||||
irrevocable (except as stated in this section) patent license to make, have made, use, offer to sell, sell, import,
|
must have a contributor license agreement on file to be signed by each Contributor, containing the license
|
||||||
and otherwise transfer the Work, where such license applies only to those patent claims licensable by You that are
|
terms below. This license serves as protection for both the Contributor as well as NetBird and its software users;
|
||||||
necessarily infringed by Your Contribution(s) alone or by combination of Your Contribution(s) with the Work to which
|
it does not change Contributor’s rights to use his/her own Contributions for any other purpose.
|
||||||
such Contribution(s) was submitted. If any entity institutes patent litigation against You or any other entity (
|
|
||||||
including a cross-claim or counterclaim in a lawsuit) alleging that your Contribution, or the Work to which you have
|
|
||||||
contributed, constitutes direct or contributory patent infringement, then any patent licenses granted to that entity
|
|
||||||
under this Agreement for that Contribution or Work shall terminate as of the date such litigation is filed.
|
|
||||||
|
|
||||||
|
## 2 Definitions
|
||||||
|
2.1 “IP Rights” shall mean all industrial and intellectual property rights, whether registered or not registered, whether created by Contributor or acquired by Contributor from third parties, and similar rights, including (but not limited to) semiconductor property rights, design rights, copyrights (including in the form of database rights and rights to software), all neighbouring rights (Leistungsschutzrechte), trademarks, service marks, titles, internet domain names, trade names and other labelling rights, rights deriving from corresponding applications and registrations of such rights as well as any licenses (Nutzungsrechte) under and entitlements to any such intellectual and industrial property rights.
|
||||||
|
|
||||||
4. You represent that you are legally entitled to grant the above license. If your employer(s) has rights to
|
2.2 "Contribution" shall mean any original work of authorship, including any modifications or additions to an existing work, that is or previously has been intentionally Submitted by Contributor to NetBird for inclusion in, or documentation of any Work.
|
||||||
intellectual property that you create that includes your Contributions, you represent that you have received
|
|
||||||
permission to make Contributions on behalf of that employer, that you will have received permission from your current
|
|
||||||
and future employers for all future Contributions, that your applicable employer has waived such rights for all of
|
|
||||||
your current and future Contributions to NetBird, or that your employer has executed a separate Corporate CLA
|
|
||||||
with NetBird.
|
|
||||||
|
|
||||||
|
2.3 "Contributor" shall mean the copyright owner or legal entity authorized by the copyright owner that is concluding this Agreement with NetBird. For legal entities, the entity making a Contribution and all other entities that control, are controlled by, or are under common control with that entity are considered to be a single Contributor. For the purposes of this definition, "control" means (i) the power, direct or indirect, to cause the direction or management of such entity, whether by contract or otherwise, or (ii) ownership of fifty percent (50%) or more of the outstanding shares, or (iii) beneficial ownership of such entity.
|
||||||
|
|
||||||
5. You represent that each of Your Contributions is Your original creation (see section 7 for submissions on behalf of
|
2.4 "Submitted" shall mean any form of electronic, verbal, or written communication sent to NetBird or its representatives, including but not limited to communication on electronic mailing lists, source code control systems, and issue tracking systems that are managed by, or on behalf of, NetBird for the purpose of discussing and improving the Work, but excluding communication that is marked or otherwise designated in writing by Contributor as "Not a Contribution".
|
||||||
others). You represent that Your Contribution submissions include complete details of any third-party license or
|
|
||||||
other restriction (including, but not limited to, related patents and trademarks) of which you are personally aware
|
|
||||||
and which are associated with any part of Your Contributions.
|
|
||||||
|
|
||||||
|
2.5 "Work" means any of the products owned or managed by NetBird, in particular, but not exclusively, software.
|
||||||
|
|
||||||
6. You are not expected to provide support for Your Contributions, except to the extent You desire to provide support.
|
## 3 Licenses
|
||||||
You may provide support for free, for a fee, or not at all. Unless required by applicable law or agreed to in
|
3.1 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable license to reproduce by any means and in any form, in whole or in part, permanently or temporarily, the Contributions (including loading, displaying, executing, transmitting or storing works for the purpose of executing and processing data or transferring them to video, audio and other data carriers), including the right to distribute, display and present such Contributions and make them available to the public (e.g. via the internet) and to transmit and display such Contributions by any means. The license also includes the right to modify, translate, adapt, edit and otherwise alter the Contributions and to use these results in the same manner as the original Contributions and derivative works. Except for licenses in patents acc. to Sec. 3, such license refers to any IP Rights in the Contributions and derivative works. The Contributor acknowledges that NetBird is not required to credit them by name for their Contribution and agrees to waive any moral rights associated with their Contribution in relation to NetBird or its sublicensees.
|
||||||
writing, You provide Your Contributions on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either
|
|
||||||
express or implied, including, without limitation, any warranties or conditions of TITLE, NON- INFRINGEMENT,
|
|
||||||
MERCHANTABILITY, or FITNESS FOR A PARTICULAR PURPOSE.
|
|
||||||
|
|
||||||
|
3.2 Subject to the terms and conditions of this agreement, Contributor hereby grants to NetBird and to recipients of software distributed by NetBird a perpetual, worldwide, non-exclusive, no-charge, royalty-free, irrevocable (except as stated in this section) patent license in the Contributions to make, have made, use, sell, offer to sell, import, and otherwise transfer the Work, where such license applies only to those patent claims licensable by the Contributor which are necessarily infringed by Contributor‘s Contribution(s) alone or by combination of Contributor’s Contribution(s) with the Work to which such Contribution(s) was Submitted.
|
||||||
|
|
||||||
7. Should You wish to submit work that is not Your original creation, You may submit it to NetBird separately from
|
3.3 NetBird hereby accepts such licenses.
|
||||||
any Contribution, identifying the complete details of its source and of any license or other restriction (including,
|
|
||||||
but not limited to, related patents, trademarks, and license agreements) of which you are personally aware, and
|
|
||||||
conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
|
||||||
|
|
||||||
|
## 4 Contributor’s Representations
|
||||||
|
4.1 Contributor represents that Contributor is legally entitled to grant the above license. If Contributor’s employer has IP Rights to Contributor’s Contributions, Contributor represent that he/she has received permission to make Contributions on behalf of such employer, that such employer has waived such IP Rights to the Contributions of Contributor to NetBird, or that such employer has executed a separate contributor license agreement with NetBird.
|
||||||
|
|
||||||
|
4.2 Contributor represents that any Contribution is his/her original creation.
|
||||||
|
|
||||||
|
4.3 Contributor represents to his/her best knowledge that any Contribution does not violate any third party IP Rights.
|
||||||
|
|
||||||
|
4.4 Contributor represents that any Contribution submission includes complete details of any third-party license or other restriction (including, but not limited to, related patents and trademarks) of which Contributor is personally aware and which are associated with any part of the Contribution.
|
||||||
|
|
||||||
|
4.5 The Contributor represents that their Contribution does not include any work distributed under a copyleft license.
|
||||||
|
|
||||||
|
## 5 Information obligation
|
||||||
|
Contributor agrees to notify NetBird of any facts or circumstances of which Contributor become aware that would make these representations inaccurate in any respect.
|
||||||
|
|
||||||
|
## 6 Submission of Third-Party works
|
||||||
|
Should Contributor wish to submit work that is not Contributor’s original creation, Contributor may submit it to NetBird separately from any Contribution, identifying the complete details of its source and of any license or other restriction (including, but not limited to, related patents, trademarks, and license agreements) of which Contributor are personally aware, and conspicuously marking the work as "Submitted on behalf of a third-party: [named here]".
|
||||||
|
|
||||||
|
## 7 No Consideration
|
||||||
|
Unless compensation is mandatory under statutory law, no compensation for any license under this agreement shall be payable.
|
||||||
|
|
||||||
|
## 8 Final Provisions
|
||||||
|
8.1 Laws. This Agreement is governed by the laws of the Federal Republic of Germany.
|
||||||
|
|
||||||
|
8.2 Venue. Place of jurisdiction shall, to the extent legally permissible, be Berlin, Germany.
|
||||||
|
|
||||||
|
8.3 Severability. If any provision in this agreement is unlawful, invalid or ineffective, it shall not affect the enforceability or effectiveness of the remainder of this agreement. The parties agree to replace any unlawful, invalid or ineffective provision with a provision that comes as close as possible to the commercial intent and purpose of the original provision. This section also applies accordingly to any gaps in the contract.
|
||||||
|
|
||||||
|
8.4 Variations. Any variations, amendments or supplements to this Agreement must be in writing. This also applies to any variation of this Section 8.4.
|
||||||
|
|
||||||
8. You agree to notify NetBird of any facts or circumstances of which you become aware that would make these
|
|
||||||
representations inaccurate in any respect.
|
|
||||||
|
|||||||
28
README.md
28
README.md
@@ -12,7 +12,7 @@
|
|||||||
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
<img src="https://img.shields.io/badge/license-BSD--3-blue" />
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">
|
<a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">
|
||||||
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
<img src="https://img.shields.io/badge/slack-@netbird-red.svg?logo=slack"/>
|
||||||
</a>
|
</a>
|
||||||
<br>
|
<br>
|
||||||
@@ -29,13 +29,13 @@
|
|||||||
<br/>
|
<br/>
|
||||||
See <a href="https://netbird.io/docs/">Documentation</a>
|
See <a href="https://netbird.io/docs/">Documentation</a>
|
||||||
<br/>
|
<br/>
|
||||||
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-2utg2ncdz-W7LEB6toRBLE1Jca37dYpg">Slack channel</a>
|
Join our <a href="https://join.slack.com/t/netbirdio/shared_invite/zt-31rofwmxc-27akKd0Le0vyRpBcwXkP0g">Slack channel</a>
|
||||||
<br/>
|
<br/>
|
||||||
|
|
||||||
</strong>
|
</strong>
|
||||||
<br>
|
<br>
|
||||||
<a href="https://netbird.io/webinars/achieve-zero-trust-access-to-k8s?utm_source=github&utm_campaign=2502%20-%20webinar%20-%20How%20to%20Achieve%20Zero%20Trust%20Access%20to%20Kubernetes%20-%20Effortlessly&utm_medium=github">
|
<a href="https://github.com/netbirdio/kubernetes-operator">
|
||||||
Webinar: Securely Access Kubernetes without Port Forwarding and Jump Hosts
|
New: NetBird Kubernetes Operator
|
||||||
</a>
|
</a>
|
||||||
</p>
|
</p>
|
||||||
|
|
||||||
@@ -57,16 +57,16 @@
|
|||||||
|
|
||||||
### Key features
|
### Key features
|
||||||
|
|
||||||
| Connectivity | Management | Security | Automation | Platforms |
|
| Connectivity | Management | Security | Automation| Platforms |
|
||||||
|------------------------------------------------------------------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------|---------------------------------------------------------------------------------------------------------------------------------------|------------------------------------------------------------------------------------------------------------------------------------------|-----------------------------------------------------------------------------------------|
|
|----|----|----|----|----|
|
||||||
| <ul><li> - \[x] Kernel WireGuard </ul></li> | <ul><li> - \[x] [Admin Web UI](https://github.com/netbirdio/dashboard) </ul></li> | <ul><li> - \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login) </ul></li> | <ul><li> - \[x] [Public API](https://docs.netbird.io/api) </ul></li> | <ul><li> - \[x] Linux </ul></li> |
|
| <ul><li>- \[x] Kernel WireGuard</ul></li> | <ul><li>- \[x] [Admin Web UI](https://github.com/netbirdio/dashboard)</ul></li> | <ul><li>- \[x] [SSO & MFA support](https://docs.netbird.io/how-to/installation#running-net-bird-with-sso-login)</ul></li> | <ul><li>- \[x] [Public API](https://docs.netbird.io/api)</ul></li> | <ul><li>- \[x] Linux</ul></li> |
|
||||||
| <ul><li> - \[x] Peer-to-peer connections </ul></li> | <ul><li> - \[x] Auto peer discovery and configuration </ul></li> | <ul><li> - \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access) </ul></li> | <ul><li> - \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys) </ul></li> | <ul><li> - \[x] Mac </ul></li> |
|
| <ul><li>- \[x] Peer-to-peer connections</ul></li> | <ul><li>- \[x] Auto peer discovery and configuration</ui></li> | <ul><li>- \[x] [Access control - groups & rules](https://docs.netbird.io/how-to/manage-network-access)</ui></li> | <ul><li>- \[x] [Setup keys for bulk network provisioning](https://docs.netbird.io/how-to/register-machines-using-setup-keys)</ui></li> | <ul><li>- \[x] Mac</ui></li> |
|
||||||
| <ul><li> - \[x] Connection relay fallback </ul></li> | <ul><li> - \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers) </ul></li> | <ul><li> - \[x] [Activity logging](https://docs.netbird.io/how-to/monitor-system-and-network-activity) </ul></li> | <ul><li> - \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart) </ul></li> | <ul><li> - \[x] Windows </ul></li> |
|
| <ul><li>- \[x] Connection relay fallback</ui></li> | <ul><li>- \[x] [IdP integrations](https://docs.netbird.io/selfhosted/identity-providers)</ui></li> | <ul><li>- \[x] [Activity logging](https://docs.netbird.io/how-to/audit-events-logging)</ui></li> | <ul><li>- \[x] [Self-hosting quickstart script](https://docs.netbird.io/selfhosted/selfhosted-quickstart)</ui></li> | <ul><li>- \[x] Windows</ui></li> |
|
||||||
| <ul><li> - \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks) </ul></li> | <ul><li> - \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network) </ul></li> | <ul><li> - \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks) </ul></li> | <ul><li> - \[x] IdP groups sync with JWT </ul></li> | <ul><li> - \[x] Android </ul></li> |
|
| <ul><li>- \[x] [Routes to external networks](https://docs.netbird.io/how-to/routing-traffic-to-private-networks)</ui></li> | <ul><li>- \[x] [Private DNS](https://docs.netbird.io/how-to/manage-dns-in-your-network)</ui></li> | <ul><li>- \[x] [Device posture checks](https://docs.netbird.io/how-to/manage-posture-checks)</ui></li> | <ul><li>- \[x] IdP groups sync with JWT</ui></li> | <ul><li>- \[x] Android</ui></li> |
|
||||||
| <ul><li> - \[x] NAT traversal with BPF </ul></li> | <ul><li> - \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network) </ul></li> | <ul><li> - \[x] Peer-to-peer encryption </ul></li> | | <ul><li> - \[x] iOS </ul></li> |
|
| <ul><li>- \[x] NAT traversal with BPF</ui></li> | <ul><li>- \[x] [Multiuser support](https://docs.netbird.io/how-to/add-users-to-your-network)</ui></li> | <ul><li>- \[x] Peer-to-peer encryption</ui></li> || <ul><li>- \[x] iOS</ui></li> |
|
||||||
| | | <ul><li> - \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn) </ul></li> | | <ul><li> - \[x] OpenWRT </ul></li> |
|
||| <ul><li>- \[x] [Quantum-resistance with Rosenpass](https://netbird.io/knowledge-hub/the-first-quantum-resistant-mesh-vpn)</ui></li> || <ul><li>- \[x] OpenWRT</ui></li> |
|
||||||
| | | <ui><li> - \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ul></li> | | <ul><li> - \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas) </ul></li> |
|
||| <ul><li>- \[x] [Periodic re-authentication](https://docs.netbird.io/how-to/enforce-periodic-user-authentication)</ui></li> || <ul><li>- \[x] [Serverless](https://docs.netbird.io/how-to/netbird-on-faas)</ui></li> |
|
||||||
| | | | | <ul><li> - \[x] Docker </ul></li> |
|
||||| <ul><li>- \[x] Docker</ui></li> |
|
||||||
|
|
||||||
### Quickstart with NetBird Cloud
|
### Quickstart with NetBird Cloud
|
||||||
|
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ type Anonymizer struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
func DefaultAddresses() (netip.Addr, netip.Addr) {
|
||||||
// 192.51.100.0, 100::
|
// 198.51.100.0, 100::
|
||||||
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
return netip.AddrFrom4([4]byte{198, 51, 100, 0}), netip.AddrFrom16([16]byte{0x01})
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -11,9 +11,12 @@ import (
|
|||||||
"google.golang.org/grpc/status"
|
"google.golang.org/grpc/status"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal"
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/debug"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
"github.com/netbirdio/netbird/client/server"
|
"github.com/netbirdio/netbird/client/server"
|
||||||
nbstatus "github.com/netbirdio/netbird/client/status"
|
nbstatus "github.com/netbirdio/netbird/client/status"
|
||||||
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errCloseConnection = "Failed to close connection: %v"
|
const errCloseConnection = "Failed to close connection: %v"
|
||||||
@@ -326,3 +329,34 @@ func formatDuration(d time.Duration) string {
|
|||||||
s := d / time.Second
|
s := d / time.Second
|
||||||
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
return fmt.Sprintf("%02d:%02d:%02d", h, m, s)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func generateDebugBundle(config *internal.Config, recorder *peer.Status, connectClient *internal.ConnectClient, logFilePath string) {
|
||||||
|
var networkMap *mgmProto.NetworkMap
|
||||||
|
var err error
|
||||||
|
|
||||||
|
if connectClient != nil {
|
||||||
|
networkMap, err = connectClient.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
log.Warnf("Failed to get latest network map: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
bundleGenerator := debug.NewBundleGenerator(
|
||||||
|
debug.GeneratorDependencies{
|
||||||
|
InternalConfig: config,
|
||||||
|
StatusRecorder: recorder,
|
||||||
|
NetworkMap: networkMap,
|
||||||
|
LogFile: logFilePath,
|
||||||
|
},
|
||||||
|
debug.BundleConfig{
|
||||||
|
IncludeSystemInfo: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
path, err := bundleGenerator.Generate()
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to generate debug bundle: %v", err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Generated debug bundle from SIGUSR1 at: %s", path)
|
||||||
|
}
|
||||||
|
|||||||
39
client/cmd/debug_unix.go
Normal file
39
client/cmd/debug_unix.go
Normal file
@@ -0,0 +1,39 @@
|
|||||||
|
//go:build unix
|
||||||
|
|
||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
usr1Ch := make(chan os.Signal, 1)
|
||||||
|
|
||||||
|
signal.Notify(usr1Ch, syscall.SIGUSR1)
|
||||||
|
|
||||||
|
go func() {
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
case <-usr1Ch:
|
||||||
|
log.Info("Received SIGUSR1. Triggering debug bundle generation.")
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
}
|
||||||
126
client/cmd/debug_windows.go
Normal file
126
client/cmd/debug_windows.go
Normal file
@@ -0,0 +1,126 @@
|
|||||||
|
package cmd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"errors"
|
||||||
|
"os"
|
||||||
|
"strconv"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/sys/windows"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
|
)
|
||||||
|
|
||||||
|
const (
|
||||||
|
envListenEvent = "NB_LISTEN_DEBUG_EVENT"
|
||||||
|
debugTriggerEventName = `Global\NetbirdDebugTriggerEvent`
|
||||||
|
|
||||||
|
waitTimeout = 5 * time.Second
|
||||||
|
)
|
||||||
|
|
||||||
|
// SetupDebugHandler sets up a Windows event to listen for a signal to generate a debug bundle.
|
||||||
|
// Example usage with PowerShell:
|
||||||
|
// $evt = [System.Threading.EventWaitHandle]::OpenExisting("Global\NetbirdDebugTriggerEvent")
|
||||||
|
// $evt.Set()
|
||||||
|
// $evt.Close()
|
||||||
|
func SetupDebugHandler(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
) {
|
||||||
|
env := os.Getenv(envListenEvent)
|
||||||
|
if env == "" {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
listenEvent, err := strconv.ParseBool(env)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to parse %s: %v", envListenEvent, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if !listenEvent {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
eventNamePtr, err := windows.UTF16PtrFromString(debugTriggerEventName)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to convert event name '%s' to UTF16: %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: restrict access by ACL
|
||||||
|
eventHandle, err := windows.CreateEvent(nil, 1, 0, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
if errors.Is(err, windows.ERROR_ALREADY_EXISTS) {
|
||||||
|
log.Warnf("Debug trigger event '%s' already exists. Attempting to open.", debugTriggerEventName)
|
||||||
|
// SYNCHRONIZE is needed for WaitForSingleObject, EVENT_MODIFY_STATE for ResetEvent.
|
||||||
|
eventHandle, err = windows.OpenEvent(windows.SYNCHRONIZE|windows.EVENT_MODIFY_STATE, false, eventNamePtr)
|
||||||
|
if err != nil {
|
||||||
|
log.Errorf("Failed to open existing debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Infof("Successfully opened existing debug trigger event '%s'.", debugTriggerEventName)
|
||||||
|
} else {
|
||||||
|
log.Errorf("Failed to create debug trigger event '%s': %v", debugTriggerEventName, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if eventHandle == windows.InvalidHandle {
|
||||||
|
log.Errorf("Obtained an invalid handle for debug trigger event '%s'", debugTriggerEventName)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Infof("Debug handler waiting for signal on event: %s", debugTriggerEventName)
|
||||||
|
|
||||||
|
go waitForEvent(ctx, config, recorder, connectClient, logFilePath, eventHandle)
|
||||||
|
}
|
||||||
|
|
||||||
|
func waitForEvent(
|
||||||
|
ctx context.Context,
|
||||||
|
config *internal.Config,
|
||||||
|
recorder *peer.Status,
|
||||||
|
connectClient *internal.ConnectClient,
|
||||||
|
logFilePath string,
|
||||||
|
eventHandle windows.Handle,
|
||||||
|
) {
|
||||||
|
defer func() {
|
||||||
|
if err := windows.CloseHandle(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to close debug event handle '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
for {
|
||||||
|
if ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
status, err := windows.WaitForSingleObject(eventHandle, uint32(waitTimeout.Milliseconds()))
|
||||||
|
|
||||||
|
switch status {
|
||||||
|
case windows.WAIT_OBJECT_0:
|
||||||
|
log.Info("Received signal on debug event. Triggering debug bundle generation.")
|
||||||
|
|
||||||
|
// reset the event so it can be triggered again later (manual reset == 1)
|
||||||
|
if err := windows.ResetEvent(eventHandle); err != nil {
|
||||||
|
log.Errorf("Failed to reset debug event '%s': %v", debugTriggerEventName, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
go generateDebugBundle(config, recorder, connectClient, logFilePath)
|
||||||
|
case uint32(windows.WAIT_TIMEOUT):
|
||||||
|
|
||||||
|
default:
|
||||||
|
log.Errorf("Unexpected status %d from WaitForSingleObject for debug event '%s': %v", status, debugTriggerEventName, err)
|
||||||
|
select {
|
||||||
|
case <-time.After(5 * time.Second):
|
||||||
|
case <-ctx.Done():
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -19,6 +19,10 @@ import (
|
|||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
func init() {
|
||||||
|
loginCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
}
|
||||||
|
|
||||||
var loginCmd = &cobra.Command{
|
var loginCmd = &cobra.Command{
|
||||||
Use: "login",
|
Use: "login",
|
||||||
Short: "login to the Netbird Management Service (first run)",
|
Short: "login to the Netbird Management Service (first run)",
|
||||||
@@ -127,7 +131,7 @@ var loginCmd = &cobra.Command{
|
|||||||
}
|
}
|
||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
@@ -198,7 +202,7 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
return nil, fmt.Errorf("getting a request OAuth flow info failed: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode)
|
openURL(cmd, flowInfo.VerificationURIComplete, flowInfo.UserCode, noBrowser)
|
||||||
|
|
||||||
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
waitTimeout := time.Duration(flowInfo.ExpiresIn) * time.Second
|
||||||
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
waitCTX, c := context.WithTimeout(context.TODO(), waitTimeout)
|
||||||
@@ -212,20 +216,28 @@ func foregroundGetTokenInfo(ctx context.Context, cmd *cobra.Command, config *int
|
|||||||
return &tokenInfo, nil
|
return &tokenInfo, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string) {
|
func openURL(cmd *cobra.Command, verificationURIComplete, userCode string, noBrowser bool) {
|
||||||
var codeMsg string
|
var codeMsg string
|
||||||
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
if userCode != "" && !strings.Contains(verificationURIComplete, userCode) {
|
||||||
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
codeMsg = fmt.Sprintf("and enter the code %s to authenticate.", userCode)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if noBrowser {
|
||||||
|
cmd.Println("Use this URL to log in:\n\n" + verificationURIComplete + " " + codeMsg)
|
||||||
|
} else {
|
||||||
cmd.Println("Please do the SSO login in your browser. \n" +
|
cmd.Println("Please do the SSO login in your browser. \n" +
|
||||||
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
"If your browser didn't open automatically, use this URL to log in:\n\n" +
|
||||||
verificationURIComplete + " " + codeMsg)
|
verificationURIComplete + " " + codeMsg)
|
||||||
|
}
|
||||||
|
|
||||||
cmd.Println("")
|
cmd.Println("")
|
||||||
|
|
||||||
|
if !noBrowser {
|
||||||
if err := open.Run(verificationURIComplete); err != nil {
|
if err := open.Run(verificationURIComplete); err != nil {
|
||||||
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
cmd.Println("\nAlternatively, you may want to use a setup key, see:\n\n" +
|
||||||
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
"https://docs.netbird.io/how-to/register-machines-using-setup-keys")
|
||||||
}
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
// isLinuxRunningDesktop checks if a Linux OS is running desktop environment
|
||||||
|
|||||||
@@ -180,7 +180,7 @@ func init() {
|
|||||||
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
upCmd.PersistentFlags().BoolVar(&serverSSHAllowed, serverSSHAllowedFlag, false, "Allow SSH server on peer. If enabled, the SSH server will be permitted")
|
||||||
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
upCmd.PersistentFlags().BoolVar(&autoConnectDisabled, disableAutoConnectFlag, false, "Disables auto-connect feature. If enabled, then the client won't connect automatically when the service starts.")
|
||||||
|
|
||||||
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", false, "Adds system information to the debug bundle")
|
debugCmd.PersistentFlags().BoolVarP(&debugSystemInfoFlag, systemInfoFlag, "S", true, "Adds system information to the debug bundle")
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetupCloseHandler handles SIGTERM signal and exits with success
|
// SetupCloseHandler handles SIGTERM signal and exits with success
|
||||||
|
|||||||
@@ -115,6 +115,7 @@ var runCmd = &cobra.Command{
|
|||||||
|
|
||||||
ctx, cancel := context.WithCancel(cmd.Context())
|
ctx, cancel := context.WithCancel(cmd.Context())
|
||||||
SetupCloseHandler(ctx, cancel)
|
SetupCloseHandler(ctx, cancel)
|
||||||
|
SetupDebugHandler(ctx, nil, nil, nil, logFile)
|
||||||
|
|
||||||
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
s, err := newSVC(newProgram(ctx, cancel), newSVCConfig())
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -6,14 +6,17 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
"go.opentelemetry.io/otel"
|
"go.opentelemetry.io/otel"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/util"
|
"github.com/netbirdio/netbird/util"
|
||||||
|
|
||||||
@@ -31,7 +34,7 @@ import (
|
|||||||
|
|
||||||
func startTestingServices(t *testing.T) string {
|
func startTestingServices(t *testing.T) string {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
config := &mgmt.Config{}
|
config := &types.Config{}
|
||||||
_, err := util.ReadJson("../testdata/management.json", config)
|
_, err := util.ReadJson("../testdata/management.json", config)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
@@ -66,7 +69,7 @@ func startSignal(t *testing.T) (*grpc.Server, net.Listener) {
|
|||||||
return s, lis
|
return s, lis
|
||||||
}
|
}
|
||||||
|
|
||||||
func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.Server, net.Listener) {
|
func startManagement(t *testing.T, config *types.Config, testFile string) (*grpc.Server, net.Listener) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
lis, err := net.Listen("tcp", ":0")
|
lis, err := net.Listen("tcp", ":0")
|
||||||
@@ -89,14 +92,24 @@ func startManagement(t *testing.T, config *mgmt.Config, testFile string) (*grpc.
|
|||||||
|
|
||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
|
||||||
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock())
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
permissionsManagerMock := permissions.NewMockManager(ctrl)
|
||||||
|
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
accountManager, err := mgmt.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, iv, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManagerMock)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := mgmt.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := mgmt.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatal(err)
|
t.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -32,12 +32,16 @@ const (
|
|||||||
|
|
||||||
const (
|
const (
|
||||||
dnsLabelsFlag = "extra-dns-labels"
|
dnsLabelsFlag = "extra-dns-labels"
|
||||||
|
|
||||||
|
noBrowserFlag = "no-browser"
|
||||||
|
noBrowserDesc = "do not open the browser for SSO login"
|
||||||
)
|
)
|
||||||
|
|
||||||
var (
|
var (
|
||||||
foregroundMode bool
|
foregroundMode bool
|
||||||
dnsLabels []string
|
dnsLabels []string
|
||||||
dnsLabelsValidated domain.List
|
dnsLabelsValidated domain.List
|
||||||
|
noBrowser bool
|
||||||
|
|
||||||
upCmd = &cobra.Command{
|
upCmd = &cobra.Command{
|
||||||
Use: "up",
|
Use: "up",
|
||||||
@@ -65,6 +69,9 @@ func init() {
|
|||||||
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
`E.g. --extra-dns-labels vpc1 or --extra-dns-labels vpc1,mgmt1 `+
|
||||||
`or --extra-dns-labels ""`,
|
`or --extra-dns-labels ""`,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
upCmd.PersistentFlags().BoolVar(&noBrowser, noBrowserFlag, false, noBrowserDesc)
|
||||||
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func upFunc(cmd *cobra.Command, args []string) error {
|
func upFunc(cmd *cobra.Command, args []string) error {
|
||||||
@@ -212,6 +219,8 @@ func runInForegroundMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
r.GetFullStatus()
|
r.GetFullStatus()
|
||||||
|
|
||||||
connectClient := internal.NewConnectClient(ctx, config, r)
|
connectClient := internal.NewConnectClient(ctx, config, r)
|
||||||
|
SetupDebugHandler(ctx, config, r, connectClient, "")
|
||||||
|
|
||||||
return connectClient.Run(nil)
|
return connectClient.Run(nil)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -349,7 +358,7 @@ func runInDaemonMode(ctx context.Context, cmd *cobra.Command) error {
|
|||||||
|
|
||||||
if loginResp.NeedsSSOLogin {
|
if loginResp.NeedsSSOLogin {
|
||||||
|
|
||||||
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode)
|
openURL(cmd, loginResp.VerificationURIComplete, loginResp.UserCode, noBrowser)
|
||||||
|
|
||||||
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
_, err = client.WaitSSOLogin(ctx, &proto.WaitSSOLoginRequest{UserCode: loginResp.UserCode, Hostname: hostName})
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
@@ -10,17 +10,18 @@ import (
|
|||||||
|
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
// NewFirewall creates a firewall manager instance
|
// NewFirewall creates a firewall manager instance
|
||||||
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, _ *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
if !iface.IsUserspaceBind() {
|
if !iface.IsUserspaceBind() {
|
||||||
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
return nil, fmt.Errorf("not implemented for this OS: %s", runtime.GOOS)
|
||||||
}
|
}
|
||||||
|
|
||||||
// use userspace packet filtering firewall
|
// use userspace packet filtering firewall
|
||||||
fm, err := uspfilter.Create(iface, disableServerRoutes)
|
fm, err := uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,6 +15,7 @@ import (
|
|||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
nbnftables "github.com/netbirdio/netbird/client/firewall/nftables"
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -33,7 +34,7 @@ const SKIP_NFTABLES_ENV = "NB_SKIP_NFTABLES_CHECK"
|
|||||||
// FWType is the type for the firewall type
|
// FWType is the type for the firewall type
|
||||||
type FWType int
|
type FWType int
|
||||||
|
|
||||||
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, flowLogger nftypes.FlowLogger, disableServerRoutes bool) (firewall.Manager, error) {
|
||||||
// on the linux system we try to user nftables or iptables
|
// on the linux system we try to user nftables or iptables
|
||||||
// in any case, because we need to allow netbird interface traffic
|
// in any case, because we need to allow netbird interface traffic
|
||||||
// so we use AllowNetbird traffic from these firewall managers
|
// so we use AllowNetbird traffic from these firewall managers
|
||||||
@@ -47,7 +48,7 @@ func NewFirewall(iface IFaceMapper, stateManager *statemanager.Manager, disableS
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
log.Warnf("failed to create native firewall: %v. Proceeding with userspace", err)
|
||||||
}
|
}
|
||||||
return createUserspaceFirewall(iface, fm, disableServerRoutes)
|
return createUserspaceFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
func createNativeFirewall(iface IFaceMapper, stateManager *statemanager.Manager, routes bool) (firewall.Manager, error) {
|
||||||
@@ -77,12 +78,12 @@ func createFW(iface IFaceMapper) (firewall.Manager, error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool) (firewall.Manager, error) {
|
func createUserspaceFirewall(iface IFaceMapper, fm firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (firewall.Manager, error) {
|
||||||
var errUsp error
|
var errUsp error
|
||||||
if fm != nil {
|
if fm != nil {
|
||||||
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes)
|
fm, errUsp = uspfilter.CreateWithNativeFirewall(iface, fm, disableServerRoutes, flowLogger)
|
||||||
} else {
|
} else {
|
||||||
fm, errUsp = uspfilter.Create(iface, disableServerRoutes)
|
fm, errUsp = uspfilter.Create(iface, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if errUsp != nil {
|
if errUsp != nil {
|
||||||
|
|||||||
@@ -75,6 +75,7 @@ func (m *aclManager) init(stateManager *statemanager.Manager) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *aclManager) AddPeerFiltering(
|
func (m *aclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
|
|||||||
@@ -96,21 +96,22 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
//
|
//
|
||||||
// Comment will be ignored because some system this feature is not supported
|
// Comment will be ignored because some system this feature is not supported
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
_ string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
return m.aclMgr.AddPeerFiltering(ip, protocol, sPort, dPort, action, ipsetName)
|
return m.aclMgr.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -125,7 +126,7 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -196,13 +197,13 @@ func (m *Manager) AllowNetbird() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
_, err := m.AddPeerFiltering(
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
"all",
|
"all",
|
||||||
nil,
|
nil,
|
||||||
nil,
|
nil,
|
||||||
firewall.ActionAccept,
|
firewall.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
|
||||||
)
|
)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
return fmt.Errorf("allow netbird interface traffic: %w", err)
|
||||||
|
|||||||
@@ -75,7 +75,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
IsRange: true,
|
IsRange: true,
|
||||||
Values: []uint16{8043, 8046},
|
Values: []uint16{8043, 8046},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
@@ -97,7 +97,7 @@ func TestIptablesManager(t *testing.T) {
|
|||||||
// add second rule
|
// add second rule
|
||||||
ip := net.ParseIP("10.20.0.3")
|
ip := net.ParseIP("10.20.0.3")
|
||||||
port := &fw.Port{Values: []uint16{5353}}
|
port := &fw.Port{Values: []uint16{5353}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "udp", nil, port, fw.ActionAccept, "", "accept Fake DNS traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "udp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Close(nil)
|
err = manager.Close(nil)
|
||||||
@@ -148,7 +148,7 @@ func TestIptablesManagerIPSet(t *testing.T) {
|
|||||||
port := &fw.Port{
|
port := &fw.Port{
|
||||||
Values: []uint16{443},
|
Values: []uint16{443},
|
||||||
}
|
}
|
||||||
rule2, err = manager.AddPeerFiltering(ip, "tcp", port, nil, fw.ActionAccept, "default", "accept HTTPS traffic from ports range")
|
rule2, err = manager.AddPeerFiltering(nil, ip, "tcp", port, nil, fw.ActionAccept, "default")
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
require.Equal(t, r.(*Rule).ipsetName, "default-sport", "ipset name must be set")
|
||||||
@@ -216,7 +216,7 @@ func TestIptablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -15,7 +15,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -41,6 +41,8 @@ const (
|
|||||||
jumpManglePre = "jump-mangle-pre"
|
jumpManglePre = "jump-mangle-pre"
|
||||||
jumpNatPre = "jump-nat-pre"
|
jumpNatPre = "jump-nat-pre"
|
||||||
jumpNatPost = "jump-nat-post"
|
jumpNatPost = "jump-nat-post"
|
||||||
|
markManglePre = "mark-mangle-pre"
|
||||||
|
markManglePost = "mark-mangle-post"
|
||||||
matchSet = "--match-set"
|
matchSet = "--match-set"
|
||||||
|
|
||||||
dnatSuffix = "_dnat"
|
dnatSuffix = "_dnat"
|
||||||
@@ -115,12 +117,17 @@ func (r *router) init(stateManager *statemanager.Manager) error {
|
|||||||
return fmt.Errorf("create containers: %w", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
r.updateState()
|
r.updateState()
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -128,7 +135,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
@@ -347,12 +354,16 @@ func (r *router) Reset() error {
|
|||||||
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
if err := r.cleanUpDefaultForwardRules(); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
r.rules = make(map[string][]string)
|
|
||||||
|
|
||||||
if err := r.ipsetCounter.Flush(); err != nil {
|
if err := r.ipsetCounter.Flush(); err != nil {
|
||||||
merr = multierror.Append(merr, err)
|
merr = multierror.Append(merr, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.cleanupDataPlaneMark(); err != nil {
|
||||||
|
merr = multierror.Append(merr, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.rules = make(map[string][]string)
|
||||||
r.updateState()
|
r.updateState()
|
||||||
|
|
||||||
return nberrors.FormatErrorOrNil(merr)
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
@@ -422,6 +433,57 @@ func (r *router) createContainers() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// setupDataPlaneMark configures the fwmark for the data plane
|
||||||
|
func (r *router) setupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
preRule := []string{
|
||||||
|
"-i", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkIn),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePre] = preRule
|
||||||
|
}
|
||||||
|
|
||||||
|
postRule := []string{
|
||||||
|
"-o", r.wgIface.Name(),
|
||||||
|
"-m", "conntrack", "--ctstate", "NEW",
|
||||||
|
"-j", "CONNMARK", "--set-mark", fmt.Sprintf("%#x", nbnet.DataPlaneMarkOut),
|
||||||
|
}
|
||||||
|
|
||||||
|
if err := r.iptablesClient.AppendUnique(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("add mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
r.rules[markManglePost] = postRule
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (r *router) cleanupDataPlaneMark() error {
|
||||||
|
var merr *multierror.Error
|
||||||
|
if preRule, exists := r.rules[markManglePre]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPREROUTING, preRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle prerouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePre)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
if postRule, exists := r.rules[markManglePost]; exists {
|
||||||
|
if err := r.iptablesClient.DeleteIfExists(tableMangle, chainPOSTROUTING, postRule...); err != nil {
|
||||||
|
merr = multierror.Append(merr, fmt.Errorf("remove mangle postrouting rule: %w", err))
|
||||||
|
} else {
|
||||||
|
delete(r.rules, markManglePost)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nberrors.FormatErrorOrNil(merr)
|
||||||
|
}
|
||||||
|
|
||||||
func (r *router) addPostroutingRules() error {
|
func (r *router) addPostroutingRules() error {
|
||||||
// First rule for outbound masquerade
|
// First rule for outbound masquerade
|
||||||
rule1 := []string{
|
rule1 := []string{
|
||||||
@@ -463,7 +525,7 @@ func (r *router) insertEstablishedRule(chain string) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *router) addJumpRules() error {
|
func (r *router) addJumpRules() error {
|
||||||
// Jump to NAT chain
|
// Jump to nat chain
|
||||||
natRule := []string{"-j", chainRTNAT}
|
natRule := []string{"-j", chainRTNAT}
|
||||||
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
if err := r.iptablesClient.Insert(tableNat, chainPOSTROUTING, 1, natRule...); err != nil {
|
||||||
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
return fmt.Errorf("add nat postrouting jump rule: %v", err)
|
||||||
|
|||||||
@@ -46,7 +46,9 @@ func TestIptablesManager_RestoreOrCreateContainers(t *testing.T) {
|
|||||||
// 5. jump rule to PRE nat chain
|
// 5. jump rule to PRE nat chain
|
||||||
// 6. static outbound masquerade rule
|
// 6. static outbound masquerade rule
|
||||||
// 7. static return masquerade rule
|
// 7. static return masquerade rule
|
||||||
require.Len(t, manager.rules, 7, "should have created rules map")
|
// 8. mangle prerouting mark rule
|
||||||
|
// 9. mangle postrouting mark rule
|
||||||
|
require.Len(t, manager.rules, 9, "should have created rules map")
|
||||||
|
|
||||||
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
exists, err := manager.iptablesClient.Exists(tableNat, chainPOSTROUTING, "-j", chainRTNAT)
|
||||||
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
require.NoError(t, err, "should be able to query the iptables %s table and %s chain", tableNat, chainPOSTROUTING)
|
||||||
@@ -330,7 +332,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
// Check if the rule is in the internal map
|
// Check if the rule is in the internal map
|
||||||
|
|||||||
@@ -65,13 +65,13 @@ type Manager interface {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
AddPeerFiltering(
|
AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto Protocol,
|
proto Protocol,
|
||||||
sPort *Port,
|
sPort *Port,
|
||||||
dPort *Port,
|
dPort *Port,
|
||||||
action Action,
|
action Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]Rule, error)
|
) ([]Rule, error)
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
@@ -80,7 +80,15 @@ type Manager interface {
|
|||||||
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
// IsServerRouteSupported returns true if the firewall supports server side routing operations
|
||||||
IsServerRouteSupported() bool
|
IsServerRouteSupported() bool
|
||||||
|
|
||||||
AddRouteFiltering(source []netip.Prefix, destination netip.Prefix, proto Protocol, sPort *Port, dPort *Port, action Action) (Rule, error)
|
AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
|
sources []netip.Prefix,
|
||||||
|
destination netip.Prefix,
|
||||||
|
proto Protocol,
|
||||||
|
sPort *Port,
|
||||||
|
dPort *Port,
|
||||||
|
action Action,
|
||||||
|
) (Rule, error)
|
||||||
|
|
||||||
// DeleteRouteRule deletes a routing rule
|
// DeleteRouteRule deletes a routing rule
|
||||||
DeleteRouteRule(rule Rule) error
|
DeleteRouteRule(rule Rule) error
|
||||||
|
|||||||
@@ -27,7 +27,8 @@ const (
|
|||||||
// filter chains contains the rules that jump to the rules chains
|
// filter chains contains the rules that jump to the rules chains
|
||||||
chainNameInputFilter = "netbird-acl-input-filter"
|
chainNameInputFilter = "netbird-acl-input-filter"
|
||||||
chainNameForwardFilter = "netbird-acl-forward-filter"
|
chainNameForwardFilter = "netbird-acl-forward-filter"
|
||||||
chainNamePrerouting = "netbird-rt-prerouting"
|
chainNameManglePrerouting = "netbird-mangle-prerouting"
|
||||||
|
chainNameManglePostrouting = "netbird-mangle-postrouting"
|
||||||
|
|
||||||
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
allowNetbirdInputRuleID = "allow Netbird incoming traffic"
|
||||||
)
|
)
|
||||||
@@ -84,13 +85,13 @@ func (m *AclManager) init(workTable *nftables.Table) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *AclManager) AddPeerFiltering(
|
func (m *AclManager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
var ipset *nftables.Set
|
var ipset *nftables.Set
|
||||||
if ipsetName != "" {
|
if ipsetName != "" {
|
||||||
@@ -102,7 +103,7 @@ func (m *AclManager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
newRules := make([]firewall.Rule, 0, 2)
|
newRules := make([]firewall.Rule, 0, 2)
|
||||||
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset, comment)
|
ioRule, err := m.addIOFiltering(ip, proto, sPort, dPort, action, ipset)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -256,7 +257,6 @@ func (m *AclManager) addIOFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipset *nftables.Set,
|
ipset *nftables.Set,
|
||||||
comment string,
|
|
||||||
) (*Rule, error) {
|
) (*Rule, error) {
|
||||||
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
ruleId := generatePeerRuleId(ip, sPort, dPort, action, ipset)
|
||||||
if r, ok := m.rules[ruleId]; ok {
|
if r, ok := m.rules[ruleId]; ok {
|
||||||
@@ -338,7 +338,7 @@ func (m *AclManager) addIOFiltering(
|
|||||||
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
mainExpressions = append(mainExpressions, &expr.Verdict{Kind: expr.VerdictDrop})
|
||||||
}
|
}
|
||||||
|
|
||||||
userData := []byte(strings.Join([]string{ruleId, comment}, " "))
|
userData := []byte(ruleId)
|
||||||
|
|
||||||
chain := m.chainInputRules
|
chain := m.chainInputRules
|
||||||
nftRule := m.rConn.AddRule(&nftables.Rule{
|
nftRule := m.rConn.AddRule(&nftables.Rule{
|
||||||
@@ -463,13 +463,15 @@ func (m *AclManager) createDefaultChains() (err error) {
|
|||||||
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
// go through the input filter as well. This will enable e.g. Docker services to keep working by accessing the
|
||||||
// netbird peer IP.
|
// netbird peer IP.
|
||||||
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
func (m *AclManager) allowRedirectedTraffic(chainFwFilter *nftables.Chain) error {
|
||||||
m.chainPrerouting = m.rConn.AddChain(&nftables.Chain{
|
// Chain is created by route manager
|
||||||
Name: chainNamePrerouting,
|
// TODO: move creation to a common place
|
||||||
|
m.chainPrerouting = &nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
Table: m.workTable,
|
Table: m.workTable,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
})
|
}
|
||||||
|
|
||||||
m.addFwmarkToForward(chainFwFilter)
|
m.addFwmarkToForward(chainFwFilter)
|
||||||
|
|
||||||
|
|||||||
@@ -113,13 +113,13 @@ func (m *Manager) Init(stateManager *statemanager.Manager) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
@@ -129,10 +129,11 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
return nil, fmt.Errorf("unsupported IP version: %s", ip.String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.aclManager.AddPeerFiltering(ip, proto, sPort, dPort, action, ipsetName, comment)
|
return m.aclManager.AddPeerFiltering(id, ip, proto, sPort, dPort, action, ipsetName)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -147,7 +148,7 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
return nil, fmt.Errorf("unsupported IP version: %s", destination.Addr().String())
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.router.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.router.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DeletePeerRule from the firewall by rule definition
|
// DeletePeerRule from the firewall by rule definition
|
||||||
|
|||||||
@@ -74,7 +74,7 @@ func TestNftablesManager(t *testing.T) {
|
|||||||
|
|
||||||
testClient := &nftables.Conn{}
|
testClient := &nftables.Conn{}
|
||||||
|
|
||||||
rule, err := manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "", "")
|
rule, err := manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{53}}, fw.ActionDrop, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
err = manager.Flush()
|
err = manager.Flush()
|
||||||
@@ -201,7 +201,7 @@ func TestNFtablesCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
|
|
||||||
if i%100 == 0 {
|
if i%100 == 0 {
|
||||||
@@ -283,10 +283,11 @@ func TestNftablesManagerCompatibilityWithIptables(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
ip := net.ParseIP("100.96.0.1")
|
ip := net.ParseIP("100.96.0.1")
|
||||||
_, err = manager.AddPeerFiltering(ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "", "test rule")
|
_, err = manager.AddPeerFiltering(nil, ip, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept, "")
|
||||||
require.NoError(t, err, "failed to add peer filtering rule")
|
require.NoError(t, err, "failed to add peer filtering rule")
|
||||||
|
|
||||||
_, err = manager.AddRouteFiltering(
|
_, err = manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
[]netip.Prefix{netip.MustParsePrefix("192.168.2.0/24")},
|
||||||
netip.MustParsePrefix("10.1.0.0/24"),
|
netip.MustParsePrefix("10.1.0.0/24"),
|
||||||
fw.ProtocolTCP,
|
fw.ProtocolTCP,
|
||||||
|
|||||||
@@ -20,7 +20,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/id"
|
nbid "github.com/netbirdio/netbird/client/internal/acl/id"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
"github.com/netbirdio/netbird/client/internal/routemanager/ipfwdstate"
|
||||||
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
"github.com/netbirdio/netbird/client/internal/routemanager/refcounter"
|
||||||
nbnet "github.com/netbirdio/netbird/util/net"
|
nbnet "github.com/netbirdio/netbird/util/net"
|
||||||
@@ -100,6 +100,10 @@ func (r *router) init(workTable *nftables.Table) error {
|
|||||||
return fmt.Errorf("create containers: %w", err)
|
return fmt.Errorf("create containers: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.setupDataPlaneMark(); err != nil {
|
||||||
|
log.Errorf("failed to set up data plane mark: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -196,15 +200,21 @@ func (r *router) createContainers() error {
|
|||||||
Type: nftables.ChainTypeNAT,
|
Type: nftables.ChainTypeNAT,
|
||||||
})
|
})
|
||||||
|
|
||||||
// Chain is created by acl manager
|
r.chains[chainNameManglePostrouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
// TODO: move creation to a common place
|
Name: chainNameManglePostrouting,
|
||||||
r.chains[chainNamePrerouting] = &nftables.Chain{
|
Table: r.workTable,
|
||||||
Name: chainNamePrerouting,
|
Hooknum: nftables.ChainHookPostrouting,
|
||||||
|
Priority: nftables.ChainPriorityMangle,
|
||||||
|
Type: nftables.ChainTypeFilter,
|
||||||
|
})
|
||||||
|
|
||||||
|
r.chains[chainNameManglePrerouting] = r.conn.AddChain(&nftables.Chain{
|
||||||
|
Name: chainNameManglePrerouting,
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Hooknum: nftables.ChainHookPrerouting,
|
Hooknum: nftables.ChainHookPrerouting,
|
||||||
Priority: nftables.ChainPriorityMangle,
|
Priority: nftables.ChainPriorityMangle,
|
||||||
Type: nftables.ChainTypeFilter,
|
Type: nftables.ChainTypeFilter,
|
||||||
}
|
})
|
||||||
|
|
||||||
// Add the single NAT rule that matches on mark
|
// Add the single NAT rule that matches on mark
|
||||||
if err := r.addPostroutingRules(); err != nil {
|
if err := r.addPostroutingRules(); err != nil {
|
||||||
@@ -220,7 +230,83 @@ func (r *router) createContainers() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
if err := r.conn.Flush(); err != nil {
|
if err := r.conn.Flush(); err != nil {
|
||||||
return fmt.Errorf("nftables: unable to initialize table: %v", err)
|
return fmt.Errorf("initialize tables: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// setupDataPlaneMark configures the fwmark for the data plane
|
||||||
|
func (r *router) setupDataPlaneMark() error {
|
||||||
|
if r.chains[chainNameManglePrerouting] == nil || r.chains[chainNameManglePostrouting] == nil {
|
||||||
|
return errors.New("no mangle chains found")
|
||||||
|
}
|
||||||
|
|
||||||
|
ctNew := getCtNewExprs()
|
||||||
|
preExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyIIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
preExprs = append(preExprs, ctNew...)
|
||||||
|
preExprs = append(preExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkIn),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
preNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
|
Exprs: preExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(preNftRule)
|
||||||
|
|
||||||
|
postExprs := []expr.Any{
|
||||||
|
&expr.Meta{
|
||||||
|
Key: expr.MetaKeyOIFNAME,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpEq,
|
||||||
|
Register: 1,
|
||||||
|
Data: ifname(r.wgIface.Name()),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
postExprs = append(postExprs, ctNew...)
|
||||||
|
postExprs = append(postExprs,
|
||||||
|
&expr.Immediate{
|
||||||
|
Register: 1,
|
||||||
|
Data: binaryutil.NativeEndian.PutUint32(nbnet.DataPlaneMarkOut),
|
||||||
|
},
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeyMARK,
|
||||||
|
Register: 1,
|
||||||
|
SourceRegister: true,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
postNftRule := &nftables.Rule{
|
||||||
|
Table: r.workTable,
|
||||||
|
Chain: r.chains[chainNameManglePostrouting],
|
||||||
|
Exprs: postExprs,
|
||||||
|
}
|
||||||
|
r.conn.AddRule(postNftRule)
|
||||||
|
|
||||||
|
if err := r.conn.Flush(); err != nil {
|
||||||
|
return fmt.Errorf("flush: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
@@ -228,6 +314,7 @@ func (r *router) createContainers() error {
|
|||||||
|
|
||||||
// AddRouteFiltering appends a nftables rule to the routing chain
|
// AddRouteFiltering appends a nftables rule to the routing chain
|
||||||
func (r *router) AddRouteFiltering(
|
func (r *router) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -236,7 +323,7 @@ func (r *router) AddRouteFiltering(
|
|||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
|
|
||||||
ruleKey := id.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
ruleKey := nbid.GenerateRouteRuleKey(sources, destination, proto, sPort, dPort, action)
|
||||||
if _, ok := r.rules[string(ruleKey)]; ok {
|
if _, ok := r.rules[string(ruleKey)]; ok {
|
||||||
return ruleKey, nil
|
return ruleKey, nil
|
||||||
}
|
}
|
||||||
@@ -515,26 +602,10 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
op = expr.CmpOpNeq
|
op = expr.CmpOpNeq
|
||||||
}
|
}
|
||||||
|
|
||||||
exprs := []expr.Any{
|
|
||||||
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
// We only care about NEW connections to mark them and later identify them in the postrouting chain for masquerading.
|
||||||
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
// Masquerading will take care of the conntrack state, which means we won't need to mark established connections.
|
||||||
&expr.Ct{
|
exprs := getCtNewExprs()
|
||||||
Key: expr.CtKeySTATE,
|
exprs = append(exprs,
|
||||||
Register: 1,
|
|
||||||
},
|
|
||||||
&expr.Bitwise{
|
|
||||||
SourceRegister: 1,
|
|
||||||
DestRegister: 1,
|
|
||||||
Len: 4,
|
|
||||||
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
|
||||||
Xor: binaryutil.NativeEndian.PutUint32(0),
|
|
||||||
},
|
|
||||||
&expr.Cmp{
|
|
||||||
Op: expr.CmpOpNeq,
|
|
||||||
Register: 1,
|
|
||||||
Data: []byte{0, 0, 0, 0},
|
|
||||||
},
|
|
||||||
|
|
||||||
// interface matching
|
// interface matching
|
||||||
&expr.Meta{
|
&expr.Meta{
|
||||||
Key: expr.MetaKeyIIFNAME,
|
Key: expr.MetaKeyIIFNAME,
|
||||||
@@ -545,7 +616,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
Register: 1,
|
Register: 1,
|
||||||
Data: ifname(r.wgIface.Name()),
|
Data: ifname(r.wgIface.Name()),
|
||||||
},
|
},
|
||||||
}
|
)
|
||||||
|
|
||||||
exprs = append(exprs, sourceExp...)
|
exprs = append(exprs, sourceExp...)
|
||||||
exprs = append(exprs, destExp...)
|
exprs = append(exprs, destExp...)
|
||||||
@@ -577,7 +648,7 @@ func (r *router) addNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
r.rules[ruleKey] = r.conn.AddRule(&nftables.Rule{
|
||||||
Table: r.workTable,
|
Table: r.workTable,
|
||||||
Chain: r.chains[chainNamePrerouting],
|
Chain: r.chains[chainNameManglePrerouting],
|
||||||
Exprs: exprs,
|
Exprs: exprs,
|
||||||
UserData: []byte(ruleKey),
|
UserData: []byte(ruleKey),
|
||||||
})
|
})
|
||||||
@@ -1323,3 +1394,24 @@ func applyPort(port *firewall.Port, isSource bool) []expr.Any {
|
|||||||
|
|
||||||
return exprs
|
return exprs
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getCtNewExprs() []expr.Any {
|
||||||
|
return []expr.Any{
|
||||||
|
&expr.Ct{
|
||||||
|
Key: expr.CtKeySTATE,
|
||||||
|
Register: 1,
|
||||||
|
},
|
||||||
|
&expr.Bitwise{
|
||||||
|
SourceRegister: 1,
|
||||||
|
DestRegister: 1,
|
||||||
|
Len: 4,
|
||||||
|
Mask: binaryutil.NativeEndian.PutUint32(expr.CtStateBitNEW),
|
||||||
|
Xor: binaryutil.NativeEndian.PutUint32(0),
|
||||||
|
},
|
||||||
|
&expr.Cmp{
|
||||||
|
Op: expr.CmpOpNeq,
|
||||||
|
Register: 1,
|
||||||
|
Data: []byte{0, 0, 0, 0},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -100,7 +100,7 @@ func TestNftablesManager_AddNatRule(t *testing.T) {
|
|||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := 0
|
found := 0
|
||||||
for _, chain := range rtr.chains {
|
for _, chain := range rtr.chains {
|
||||||
if chain.Name == chainNamePrerouting {
|
if chain.Name == chainNameManglePrerouting {
|
||||||
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
rules, err := nftablesTestingClient.GetRules(chain.Table, chain)
|
||||||
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
require.NoError(t, err, "should list rules for %s table and %s chain", chain.Table.Name, chain.Name)
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
@@ -141,7 +141,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
// Verify the rule was added
|
// Verify the rule was added
|
||||||
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
natRuleKey := firewall.GenKey(firewall.PreroutingFormat, testCase.InputPair)
|
||||||
found := false
|
found := false
|
||||||
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
rules, err := rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules")
|
require.NoError(t, err, "should list rules")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
@@ -157,7 +157,7 @@ func TestNftablesManager_RemoveNatRule(t *testing.T) {
|
|||||||
|
|
||||||
// Verify the rule was removed
|
// Verify the rule was removed
|
||||||
found = false
|
found = false
|
||||||
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNamePrerouting])
|
rules, err = rtr.conn.GetRules(rtr.workTable, rtr.chains[chainNameManglePrerouting])
|
||||||
require.NoError(t, err, "should list rules after removal")
|
require.NoError(t, err, "should list rules after removal")
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
if len(rule.UserData) > 0 && string(rule.UserData) == natRuleKey {
|
||||||
@@ -311,7 +311,7 @@ func TestRouter_AddRouteFiltering(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
ruleKey, err := r.AddRouteFiltering(tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
ruleKey, err := r.AddRouteFiltering(nil, tt.sources, tt.destination, tt.proto, tt.sPort, tt.dPort, tt.action)
|
||||||
require.NoError(t, err, "AddRouteFiltering failed")
|
require.NoError(t, err, "AddRouteFiltering failed")
|
||||||
|
|
||||||
t.Cleanup(func() {
|
t.Cleanup(func() {
|
||||||
|
|||||||
@@ -4,6 +4,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -16,8 +17,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
@@ -31,8 +32,8 @@ func (m *Manager) Close(stateManager *statemanager.Manager) error {
|
|||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -3,12 +3,14 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"fmt"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"os/exec"
|
"os/exec"
|
||||||
"syscall"
|
"syscall"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -20,28 +22,31 @@ const (
|
|||||||
firewallRuleName = "Netbird"
|
firewallRuleName = "Netbird"
|
||||||
)
|
)
|
||||||
|
|
||||||
// Close closes the firewall manager
|
// Reset firewall to the default state
|
||||||
func (m *Manager) Close(*statemanager.Manager) error {
|
func (m *Manager) Close(*statemanager.Manager) error {
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
m.outgoingRules = make(map[string]RuleSet)
|
m.outgoingRules = make(map[netip.Addr]RuleSet)
|
||||||
m.incomingRules = make(map[string]RuleSet)
|
m.incomingRules = make(map[netip.Addr]RuleSet)
|
||||||
|
|
||||||
if m.udpTracker != nil {
|
if m.udpTracker != nil {
|
||||||
m.udpTracker.Close()
|
m.udpTracker.Close()
|
||||||
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.icmpTracker != nil {
|
if m.icmpTracker != nil {
|
||||||
m.icmpTracker.Close()
|
m.icmpTracker.Close()
|
||||||
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.tcpTracker != nil {
|
if m.tcpTracker != nil {
|
||||||
m.tcpTracker.Close()
|
m.tcpTracker.Close()
|
||||||
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, m.flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.forwarder != nil {
|
if fwder := m.forwarder.Load(); fwder != nil {
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.logger != nil {
|
if m.logger != nil {
|
||||||
|
|||||||
@@ -1,20 +1,27 @@
|
|||||||
// common.go
|
|
||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"fmt"
|
||||||
"sync"
|
"net/netip"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// BaseConnTrack provides common fields and locking for all connection types
|
// BaseConnTrack provides common fields and locking for all connection types
|
||||||
type BaseConnTrack struct {
|
type BaseConnTrack struct {
|
||||||
SourceIP net.IP
|
FlowId uuid.UUID
|
||||||
DestIP net.IP
|
Direction nftypes.Direction
|
||||||
SourcePort uint16
|
SourceIP netip.Addr
|
||||||
DestPort uint16
|
DestIP netip.Addr
|
||||||
lastSeen atomic.Int64 // Unix nano for atomic access
|
lastSeen atomic.Int64
|
||||||
|
PacketsTx atomic.Uint64
|
||||||
|
PacketsRx atomic.Uint64
|
||||||
|
BytesTx atomic.Uint64
|
||||||
|
BytesRx atomic.Uint64
|
||||||
}
|
}
|
||||||
|
|
||||||
// these small methods will be inlined by the compiler
|
// these small methods will be inlined by the compiler
|
||||||
@@ -24,6 +31,17 @@ func (b *BaseConnTrack) UpdateLastSeen() {
|
|||||||
b.lastSeen.Store(time.Now().UnixNano())
|
b.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// UpdateCounters safely updates the packet and byte counters
|
||||||
|
func (b *BaseConnTrack) UpdateCounters(direction nftypes.Direction, bytes int) {
|
||||||
|
if direction == nftypes.Egress {
|
||||||
|
b.PacketsTx.Add(1)
|
||||||
|
b.BytesTx.Add(uint64(bytes))
|
||||||
|
} else {
|
||||||
|
b.PacketsRx.Add(1)
|
||||||
|
b.BytesRx.Add(uint64(bytes))
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
// GetLastSeen safely gets the last seen timestamp
|
// GetLastSeen safely gets the last seen timestamp
|
||||||
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
func (b *BaseConnTrack) GetLastSeen() time.Time {
|
||||||
return time.Unix(0, b.lastSeen.Load())
|
return time.Unix(0, b.lastSeen.Load())
|
||||||
@@ -35,92 +53,14 @@ func (b *BaseConnTrack) timeoutExceeded(timeout time.Duration) bool {
|
|||||||
return time.Since(lastSeen) > timeout
|
return time.Since(lastSeen) > timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
// IPAddr is a fixed-size IP address to avoid allocations
|
|
||||||
type IPAddr [16]byte
|
|
||||||
|
|
||||||
// MakeIPAddr creates an IPAddr from net.IP
|
|
||||||
func MakeIPAddr(ip net.IP) (addr IPAddr) {
|
|
||||||
// Optimization: check for v4 first as it's more common
|
|
||||||
if ip4 := ip.To4(); ip4 != nil {
|
|
||||||
copy(addr[12:], ip4)
|
|
||||||
} else {
|
|
||||||
copy(addr[:], ip.To16())
|
|
||||||
}
|
|
||||||
return addr
|
|
||||||
}
|
|
||||||
|
|
||||||
// ConnKey uniquely identifies a connection
|
// ConnKey uniquely identifies a connection
|
||||||
type ConnKey struct {
|
type ConnKey struct {
|
||||||
SrcIP IPAddr
|
SrcIP netip.Addr
|
||||||
DstIP IPAddr
|
DstIP netip.Addr
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeConnKey creates a connection key
|
func (c ConnKey) String() string {
|
||||||
func makeConnKey(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) ConnKey {
|
return fmt.Sprintf("%s:%d -> %s:%d", c.SrcIP.Unmap(), c.SrcPort, c.DstIP.Unmap(), c.DstPort)
|
||||||
return ConnKey{
|
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
|
||||||
DstIP: MakeIPAddr(dstIP),
|
|
||||||
SrcPort: srcPort,
|
|
||||||
DstPort: dstPort,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// ValidateIPs checks if IPs match without allocation
|
|
||||||
func ValidateIPs(connIP IPAddr, pktIP net.IP) bool {
|
|
||||||
if ip4 := pktIP.To4(); ip4 != nil {
|
|
||||||
// Compare IPv4 addresses (last 4 bytes)
|
|
||||||
for i := 0; i < 4; i++ {
|
|
||||||
if connIP[12+i] != ip4[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
// Compare full IPv6 addresses
|
|
||||||
ip6 := pktIP.To16()
|
|
||||||
for i := 0; i < 16; i++ {
|
|
||||||
if connIP[i] != ip6[i] {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
// PreallocatedIPs is a pool of IP byte slices to reduce allocations
|
|
||||||
type PreallocatedIPs struct {
|
|
||||||
sync.Pool
|
|
||||||
}
|
|
||||||
|
|
||||||
// NewPreallocatedIPs creates a new IP pool
|
|
||||||
func NewPreallocatedIPs() *PreallocatedIPs {
|
|
||||||
return &PreallocatedIPs{
|
|
||||||
Pool: sync.Pool{
|
|
||||||
New: func() interface{} {
|
|
||||||
ip := make(net.IP, 16)
|
|
||||||
return &ip
|
|
||||||
},
|
|
||||||
},
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Get retrieves an IP from the pool
|
|
||||||
func (p *PreallocatedIPs) Get() net.IP {
|
|
||||||
return *p.Pool.Get().(*net.IP)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Put returns an IP to the pool
|
|
||||||
func (p *PreallocatedIPs) Put(ip net.IP) {
|
|
||||||
p.Pool.Put(&ip)
|
|
||||||
}
|
|
||||||
|
|
||||||
// copyIP copies an IP address efficiently
|
|
||||||
func copyIP(dst, src net.IP) {
|
|
||||||
if len(src) == 16 {
|
|
||||||
copy(dst, src)
|
|
||||||
} else {
|
|
||||||
// Handle IPv4
|
|
||||||
copy(dst[12:], src.To4())
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,94 +1,66 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/sirupsen/logrus"
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
func BenchmarkIPOperations(b *testing.B) {
|
|
||||||
b.Run("MakeIPAddr", func(b *testing.B) {
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = MakeIPAddr(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ValidateIPs", func(b *testing.B) {
|
|
||||||
ip1 := net.ParseIP("192.168.1.1")
|
|
||||||
ip2 := net.ParseIP("192.168.1.1")
|
|
||||||
addr := MakeIPAddr(ip1)
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
_ = ValidateIPs(addr, ip2)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IPPool", func(b *testing.B) {
|
|
||||||
pool := NewPreallocatedIPs()
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
ip := pool.Get()
|
|
||||||
pool.Put(ip)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
}
|
|
||||||
|
|
||||||
// Memory pressure tests
|
// Memory pressure tests
|
||||||
func BenchmarkMemoryPressure(b *testing.B) {
|
func BenchmarkMemoryPressure(b *testing.B) {
|
||||||
b.Run("TCPHighLoad", func(b *testing.B) {
|
b.Run("TCPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck)
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), TCPAck, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("UDPHighLoad", func(b *testing.B) {
|
b.Run("UDPHighLoad", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
// Generate different IPs
|
// Generate different IPs
|
||||||
srcIPs := make([]net.IP, 100)
|
srcIPs := make([]netip.Addr, 100)
|
||||||
dstIPs := make([]net.IP, 100)
|
dstIPs := make([]netip.Addr, 100)
|
||||||
for i := 0; i < 100; i++ {
|
for i := 0; i < 100; i++ {
|
||||||
srcIPs[i] = net.IPv4(192, 168, byte(i/256), byte(i%256))
|
srcIPs[i] = netip.AddrFrom4([4]byte{192, 168, byte(i / 256), byte(i % 256)})
|
||||||
dstIPs[i] = net.IPv4(10, 0, byte(i/256), byte(i%256))
|
dstIPs[i] = netip.AddrFrom4([4]byte{10, 0, byte(i / 256), byte(i % 256)})
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
srcIdx := i % len(srcIPs)
|
srcIdx := i % len(srcIPs)
|
||||||
dstIdx := (i + 1) % len(dstIPs)
|
dstIdx := (i + 1) % len(dstIPs)
|
||||||
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIPs[srcIdx], dstIPs[dstIdx], uint16(i%65535), 80, 0)
|
||||||
|
|
||||||
// Simulate some valid inbound packets
|
// Simulate some valid inbound packets
|
||||||
if i%3 == 0 {
|
if i%3 == 0 {
|
||||||
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535))
|
tracker.IsValidInbound(dstIPs[dstIdx], srcIPs[srcIdx], 80, uint16(i%65535), 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -2,13 +2,16 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"fmt"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,18 +23,20 @@ const (
|
|||||||
|
|
||||||
// ICMPConnKey uniquely identifies an ICMP connection
|
// ICMPConnKey uniquely identifies an ICMP connection
|
||||||
type ICMPConnKey struct {
|
type ICMPConnKey struct {
|
||||||
// Supports both IPv4 and IPv6
|
SrcIP netip.Addr
|
||||||
SrcIP [16]byte
|
DstIP netip.Addr
|
||||||
DstIP [16]byte
|
ID uint16
|
||||||
Sequence uint16 // ICMP sequence number
|
}
|
||||||
ID uint16 // ICMP identifier
|
|
||||||
|
func (i ICMPConnKey) String() string {
|
||||||
|
return fmt.Sprintf("%s -> %s (id %d)", i.SrcIP, i.DstIP, i.ID)
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPConnTrack represents an ICMP connection state
|
// ICMPConnTrack represents an ICMP connection state
|
||||||
type ICMPConnTrack struct {
|
type ICMPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
Sequence uint16
|
ICMPType uint8
|
||||||
ID uint16
|
ICMPCode uint8
|
||||||
}
|
}
|
||||||
|
|
||||||
// ICMPTracker manages ICMP connection states
|
// ICMPTracker manages ICMP connection states
|
||||||
@@ -42,11 +47,11 @@ type ICMPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewICMPTracker creates a new ICMP connection tracker
|
// NewICMPTracker creates a new ICMP connection tracker
|
||||||
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
func NewICMPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *ICMPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultICMPTimeout
|
timeout = DefaultICMPTimeout
|
||||||
}
|
}
|
||||||
@@ -59,67 +64,108 @@ func NewICMPTracker(timeout time.Duration, logger *nblog.Logger) *ICMPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
cleanupTicker: time.NewTicker(ICMPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound ICMP Echo Request
|
func (t *ICMPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, id uint16, direction nftypes.Direction, size int) (ICMPConnKey, bool) {
|
||||||
func (t *ICMPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) {
|
key := ICMPConnKey{
|
||||||
key := makeICMPKey(srcIP, dstIP, id, seq)
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
t.mutex.Lock()
|
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &ICMPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
},
|
|
||||||
ID: id,
|
ID: id,
|
||||||
Sequence: seq,
|
|
||||||
}
|
}
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New ICMP connection %v", key)
|
|
||||||
}
|
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
|
||||||
func (t *ICMPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, id uint16, seq uint16, icmpType uint8) bool {
|
|
||||||
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
key := makeICMPKey(dstIP, srcIP, id, seq)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackOutbound records an outbound ICMP connection
|
||||||
|
func (t *ICMPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, size int) {
|
||||||
|
if _, exists := t.updateIfExists(dstIP, srcIP, id, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Egress, nil, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// TrackInbound records an inbound ICMP Echo Request
|
||||||
|
func (t *ICMPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, ruleId []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, id, typecode, nftypes.Ingress, ruleId, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound ICMP connections
|
||||||
|
func (t *ICMPTracker) track(srcIP netip.Addr, dstIP netip.Addr, id uint16, typecode layers.ICMPv4TypeCode, direction nftypes.Direction, ruleId []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, id, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
typ, code := typecode.Type(), typecode.Code()
|
||||||
|
|
||||||
|
// non echo requests don't need tracking
|
||||||
|
if typ != uint8(layers.ICMPv4TypeEchoRequest) {
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendStartEvent(direction, srcIP, dstIP, typ, code, ruleId, size)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &ICMPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s ICMP connection %s type %d code %d", direction, key, typ, code)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleId)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound ICMP Echo Reply matches a tracked request
|
||||||
|
func (t *ICMPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, id uint16, icmpType uint8, size int) bool {
|
||||||
|
if icmpType != uint8(layers.ICMPv4TypeEchoReply) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
key := ICMPConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
ID: id,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
conn.UpdateLastSeen()
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
conn.ID == id &&
|
|
||||||
conn.Sequence == seq
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
||||||
@@ -134,17 +180,18 @@ func (t *ICMPTracker) cleanupRoutine(ctx context.Context) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (t *ICMPTracker) cleanup() {
|
func (t *ICMPTracker) cleanup() {
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Debug("Removed ICMP connection %v (timeout)", key)
|
t.logger.Trace("Removed ICMP connection %s (timeout) [in: %d Pkts/%d B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -154,20 +201,46 @@ func (t *ICMPTracker) Close() {
|
|||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// makeICMPKey creates an ICMP connection key
|
func (t *ICMPTracker) sendEvent(typ nftypes.Type, conn *ICMPConnTrack, ruleID []byte) {
|
||||||
func makeICMPKey(srcIP net.IP, dstIP net.IP, id uint16, seq uint16) ICMPConnKey {
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
return ICMPConnKey{
|
FlowID: conn.FlowId,
|
||||||
SrcIP: MakeIPAddr(srcIP),
|
Type: typ,
|
||||||
DstIP: MakeIPAddr(dstIP),
|
RuleID: ruleID,
|
||||||
ID: id,
|
Direction: conn.Direction,
|
||||||
Sequence: seq,
|
Protocol: nftypes.ICMP, // TODO: adjust for IPv6/icmpv6
|
||||||
}
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
ICMPType: conn.ICMPType,
|
||||||
|
ICMPCode: conn.ICMPCode,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *ICMPTracker) sendStartEvent(direction nftypes.Direction, srcIP netip.Addr, dstIP netip.Addr, typ uint8, code uint8, ruleID []byte, size int) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeStart,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: direction,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
ICMPType: typ,
|
||||||
|
ICMPCode: code,
|
||||||
|
}
|
||||||
|
if direction == nftypes.Ingress {
|
||||||
|
fields.RxPackets = 1
|
||||||
|
fields.RxBytes = uint64(size)
|
||||||
|
} else {
|
||||||
|
fields.TxPackets = 1
|
||||||
|
fields.TxBytes = uint64(size)
|
||||||
|
}
|
||||||
|
t.flowLogger.StoreEvent(fields)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,39 +1,39 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
)
|
)
|
||||||
|
|
||||||
func BenchmarkICMPTracker(b *testing.B) {
|
func BenchmarkICMPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), uint16(i%65535))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewICMPTracker(DefaultICMPTimeout, logger)
|
tracker := NewICMPTracker(DefaultICMPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), uint16(i))
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 0, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), uint16(i%1000), 0)
|
tracker.IsValidInbound(dstIP, srcIP, uint16(i%1000), 0, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -4,12 +4,15 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,11 +23,11 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPSyn uint8 = 0x02
|
|
||||||
TCPAck uint8 = 0x10
|
|
||||||
TCPFin uint8 = 0x01
|
TCPFin uint8 = 0x01
|
||||||
|
TCPSyn uint8 = 0x02
|
||||||
TCPRst uint8 = 0x04
|
TCPRst uint8 = 0x04
|
||||||
TCPPush uint8 = 0x08
|
TCPPush uint8 = 0x08
|
||||||
|
TCPAck uint8 = 0x10
|
||||||
TCPUrg uint8 = 0x20
|
TCPUrg uint8 = 0x20
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -38,7 +41,36 @@ const (
|
|||||||
)
|
)
|
||||||
|
|
||||||
// TCPState represents the state of a TCP connection
|
// TCPState represents the state of a TCP connection
|
||||||
type TCPState int
|
type TCPState int32
|
||||||
|
|
||||||
|
func (s TCPState) String() string {
|
||||||
|
switch s {
|
||||||
|
case TCPStateNew:
|
||||||
|
return "New"
|
||||||
|
case TCPStateSynSent:
|
||||||
|
return "SYN Sent"
|
||||||
|
case TCPStateSynReceived:
|
||||||
|
return "SYN Received"
|
||||||
|
case TCPStateEstablished:
|
||||||
|
return "Established"
|
||||||
|
case TCPStateFinWait1:
|
||||||
|
return "FIN Wait 1"
|
||||||
|
case TCPStateFinWait2:
|
||||||
|
return "FIN Wait 2"
|
||||||
|
case TCPStateClosing:
|
||||||
|
return "Closing"
|
||||||
|
case TCPStateTimeWait:
|
||||||
|
return "Time Wait"
|
||||||
|
case TCPStateCloseWait:
|
||||||
|
return "Close Wait"
|
||||||
|
case TCPStateLastAck:
|
||||||
|
return "Last ACK"
|
||||||
|
case TCPStateClosed:
|
||||||
|
return "Closed"
|
||||||
|
default:
|
||||||
|
return "Unknown"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
const (
|
const (
|
||||||
TCPStateNew TCPState = iota
|
TCPStateNew TCPState = iota
|
||||||
@@ -54,30 +86,38 @@ const (
|
|||||||
TCPStateClosed
|
TCPStateClosed
|
||||||
)
|
)
|
||||||
|
|
||||||
// TCPConnKey uniquely identifies a TCP connection
|
|
||||||
type TCPConnKey struct {
|
|
||||||
SrcIP [16]byte
|
|
||||||
DstIP [16]byte
|
|
||||||
SrcPort uint16
|
|
||||||
DstPort uint16
|
|
||||||
}
|
|
||||||
|
|
||||||
// TCPConnTrack represents a TCP connection state
|
// TCPConnTrack represents a TCP connection state
|
||||||
type TCPConnTrack struct {
|
type TCPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
State TCPState
|
SourcePort uint16
|
||||||
established atomic.Bool
|
DestPort uint16
|
||||||
sync.RWMutex
|
state atomic.Int32
|
||||||
|
tombstone atomic.Bool
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsEstablished safely checks if connection is established
|
// GetState safely retrieves the current state
|
||||||
func (t *TCPConnTrack) IsEstablished() bool {
|
func (t *TCPConnTrack) GetState() TCPState {
|
||||||
return t.established.Load()
|
return TCPState(t.state.Load())
|
||||||
}
|
}
|
||||||
|
|
||||||
// SetEstablished safely sets the established state
|
// SetState safely updates the current state
|
||||||
func (t *TCPConnTrack) SetEstablished(state bool) {
|
func (t *TCPConnTrack) SetState(state TCPState) {
|
||||||
t.established.Store(state)
|
t.state.Store(int32(state))
|
||||||
|
}
|
||||||
|
|
||||||
|
// CompareAndSwapState atomically changes the state from old to new if current == old
|
||||||
|
func (t *TCPConnTrack) CompareAndSwapState(old, newState TCPState) bool {
|
||||||
|
return t.state.CompareAndSwap(int32(old), int32(newState))
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsTombstone safely checks if the connection is marked for deletion
|
||||||
|
func (t *TCPConnTrack) IsTombstone() bool {
|
||||||
|
return t.tombstone.Load()
|
||||||
|
}
|
||||||
|
|
||||||
|
// SetTombstone safely marks the connection for deletion
|
||||||
|
func (t *TCPConnTrack) SetTombstone() {
|
||||||
|
t.tombstone.Store(true)
|
||||||
}
|
}
|
||||||
|
|
||||||
// TCPTracker manages TCP connection states
|
// TCPTracker manages TCP connection states
|
||||||
@@ -88,11 +128,18 @@ type TCPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
timeout time.Duration
|
timeout time.Duration
|
||||||
ipPool *PreallocatedIPs
|
waitTimeout time.Duration
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewTCPTracker creates a new TCP connection tracker
|
// NewTCPTracker creates a new TCP connection tracker
|
||||||
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
func NewTCPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *TCPTracker {
|
||||||
|
waitTimeout := TimeWaitTimeout
|
||||||
|
if timeout == 0 {
|
||||||
|
timeout = DefaultTCPTimeout
|
||||||
|
} else {
|
||||||
|
waitTimeout = timeout / 45
|
||||||
|
}
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
|
|
||||||
@@ -102,179 +149,211 @@ func NewTCPTracker(timeout time.Duration, logger *nblog.Logger) *TCPTracker {
|
|||||||
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
cleanupTicker: time.NewTicker(TCPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
ipPool: NewPreallocatedIPs(),
|
waitTimeout: waitTimeout,
|
||||||
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
return tracker
|
return tracker
|
||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound processes an outbound TCP packet and updates connection state
|
func (t *TCPTracker) updateIfExists(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
func (t *TCPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) {
|
key := ConnKey{
|
||||||
// Create key before lock
|
SrcIP: srcIP,
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
t.mutex.Lock()
|
DstPort: dstPort,
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
// Use preallocated IPs
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &TCPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
State: TCPStateNew,
|
|
||||||
}
|
}
|
||||||
conn.UpdateLastSeen()
|
|
||||||
conn.established.Store(false)
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New TCP connection: %s:%d -> %s:%d", srcIP, srcPort, dstIP, dstPort)
|
|
||||||
}
|
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
// Lock individual connection for state update
|
|
||||||
conn.Lock()
|
|
||||||
t.updateState(conn, flags, true)
|
|
||||||
conn.Unlock()
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
|
||||||
func (t *TCPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16, flags uint8) bool {
|
|
||||||
if !isValidFlagCombination(flags) {
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
return false
|
t.updateState(key, conn, flags, direction, size)
|
||||||
|
return key, true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Handle RST packets
|
return key, false
|
||||||
if flags&TCPRst != 0 {
|
|
||||||
conn.Lock()
|
|
||||||
if conn.IsEstablished() || conn.State == TCPStateSynSent || conn.State == TCPStateSynReceived {
|
|
||||||
conn.State = TCPStateClosed
|
|
||||||
conn.SetEstablished(false)
|
|
||||||
conn.Unlock()
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
conn.Unlock()
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
conn.Lock()
|
|
||||||
t.updateState(conn, flags, false)
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
isEstablished := conn.IsEstablished()
|
|
||||||
isValidState := t.isValidStateForFlags(conn.State, flags)
|
|
||||||
conn.Unlock()
|
|
||||||
|
|
||||||
return isEstablished || isValidState
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// updateState updates the TCP connection state based on flags
|
// TrackOutbound records an outbound TCP connection
|
||||||
func (t *TCPTracker) updateState(conn *TCPConnTrack, flags uint8, isOutbound bool) {
|
func (t *TCPTracker) TrackOutbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) {
|
||||||
// Handle RST flag specially - it always causes transition to closed
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, flags, nftypes.Egress, size); !exists {
|
||||||
if flags&TCPRst != 0 {
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
conn.State = TCPStateClosed
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Egress, nil, size)
|
||||||
conn.SetEstablished(false)
|
}
|
||||||
|
}
|
||||||
|
|
||||||
t.logger.Trace("TCP connection reset: %s:%d -> %s:%d",
|
// TrackInbound processes an inbound TCP packet and updates connection state
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
func (t *TCPTracker) TrackInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, ruleID []byte, size int) {
|
||||||
|
t.track(srcIP, dstIP, srcPort, dstPort, flags, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *TCPTracker) track(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, flags, direction, size)
|
||||||
|
if exists || flags&TCPSyn == 0 {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
switch conn.State {
|
conn := &TCPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
conn.tombstone.Store(false)
|
||||||
|
conn.state.Store(int32(TCPStateNew))
|
||||||
|
|
||||||
|
t.logger.Trace("New %s TCP connection: %s", direction, key)
|
||||||
|
t.updateState(key, conn, flags, direction, size)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound TCP packet matches a tracked connection
|
||||||
|
func (t *TCPTracker) IsValidInbound(srcIP, dstIP netip.Addr, srcPort, dstPort uint16, flags uint8, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.IsTombstone() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
currentState := conn.GetState()
|
||||||
|
if !t.isValidStateForFlags(currentState, flags) {
|
||||||
|
t.logger.Warn("TCP state %s is not valid with flags %x for connection %s", currentState, flags, key)
|
||||||
|
// allow all flags for established for now
|
||||||
|
if currentState == TCPStateEstablished {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
t.updateState(key, conn, flags, nftypes.Ingress, size)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
// updateState updates the TCP connection state based on flags
|
||||||
|
func (t *TCPTracker) updateState(key ConnKey, conn *TCPConnTrack, flags uint8, packetDir nftypes.Direction, size int) {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(packetDir, size)
|
||||||
|
|
||||||
|
currentState := conn.GetState()
|
||||||
|
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
if conn.CompareAndSwapState(currentState, TCPStateClosed) {
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.logger.Trace("TCP connection reset: %s (dir: %s) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, packetDir, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
var newState TCPState
|
||||||
|
switch currentState {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck == 0 {
|
||||||
conn.State = TCPStateSynSent
|
if conn.Direction == nftypes.Egress {
|
||||||
|
newState = TCPStateSynSent
|
||||||
|
} else {
|
||||||
|
newState = TCPStateSynReceived
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
if flags&TCPSyn != 0 && flags&TCPAck != 0 {
|
||||||
if isOutbound {
|
if packetDir != conn.Direction {
|
||||||
conn.State = TCPStateSynReceived
|
newState = TCPStateEstablished
|
||||||
} else {
|
} else {
|
||||||
// Simultaneous open
|
// Simultaneous open
|
||||||
conn.State = TCPStateEstablished
|
newState = TCPStateSynReceived
|
||||||
conn.SetEstablished(true)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateSynReceived:
|
case TCPStateSynReceived:
|
||||||
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
if flags&TCPAck != 0 && flags&TCPSyn == 0 {
|
||||||
conn.State = TCPStateEstablished
|
if packetDir == conn.Direction {
|
||||||
conn.SetEstablished(true)
|
newState = TCPStateEstablished
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateEstablished:
|
case TCPStateEstablished:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
if isOutbound {
|
if packetDir == conn.Direction {
|
||||||
conn.State = TCPStateFinWait1
|
newState = TCPStateFinWait1
|
||||||
} else {
|
} else {
|
||||||
conn.State = TCPStateCloseWait
|
newState = TCPStateCloseWait
|
||||||
}
|
}
|
||||||
conn.SetEstablished(false)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
|
if packetDir != conn.Direction {
|
||||||
switch {
|
switch {
|
||||||
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
case flags&TCPFin != 0 && flags&TCPAck != 0:
|
||||||
// Simultaneous close - both sides sent FIN
|
newState = TCPStateClosing
|
||||||
conn.State = TCPStateClosing
|
|
||||||
case flags&TCPFin != 0:
|
case flags&TCPFin != 0:
|
||||||
conn.State = TCPStateFinWait2
|
newState = TCPStateClosing
|
||||||
case flags&TCPAck != 0:
|
case flags&TCPAck != 0:
|
||||||
conn.State = TCPStateFinWait2
|
newState = TCPStateFinWait2
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateFinWait2:
|
case TCPStateFinWait2:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
newState = TCPStateTimeWait
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateClosing:
|
case TCPStateClosing:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateTimeWait
|
newState = TCPStateTimeWait
|
||||||
// Keep established = false from previous state
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection closed (simultaneous) - %s:%d -> %s:%d",
|
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateCloseWait:
|
case TCPStateCloseWait:
|
||||||
if flags&TCPFin != 0 {
|
if flags&TCPFin != 0 {
|
||||||
conn.State = TCPStateLastAck
|
newState = TCPStateLastAck
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
if flags&TCPAck != 0 {
|
if flags&TCPAck != 0 {
|
||||||
conn.State = TCPStateClosed
|
newState = TCPStateClosed
|
||||||
|
}
|
||||||
t.logger.Trace("TCP connection gracefully closed: %s:%d -> %s:%d",
|
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
case TCPStateTimeWait:
|
if newState != 0 && conn.CompareAndSwapState(currentState, newState) {
|
||||||
// Stay in TIME-WAIT for 2MSL before transitioning to closed
|
t.logger.Trace("TCP connection %s transitioned from %s to %s (dir: %s)", key, currentState, newState, packetDir)
|
||||||
// This is handled by the cleanup routine
|
|
||||||
|
|
||||||
t.logger.Trace("TCP connection completed - %s:%d -> %s:%d",
|
switch newState {
|
||||||
conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
case TCPStateTimeWait:
|
||||||
|
t.logger.Trace("TCP connection %s completed [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
|
||||||
|
case TCPStateClosed:
|
||||||
|
conn.SetTombstone()
|
||||||
|
t.logger.Trace("TCP connection %s closed gracefully [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -283,18 +362,22 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
if !isValidFlagCombination(flags) {
|
if !isValidFlagCombination(flags) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
if flags&TCPRst != 0 {
|
||||||
|
if state == TCPStateSynSent {
|
||||||
|
return flags&TCPAck != 0
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
switch state {
|
switch state {
|
||||||
case TCPStateNew:
|
case TCPStateNew:
|
||||||
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
return flags&TCPSyn != 0 && flags&TCPAck == 0
|
||||||
case TCPStateSynSent:
|
case TCPStateSynSent:
|
||||||
|
// TODO: support simultaneous open
|
||||||
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
return flags&TCPSyn != 0 && flags&TCPAck != 0
|
||||||
case TCPStateSynReceived:
|
case TCPStateSynReceived:
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateEstablished:
|
case TCPStateEstablished:
|
||||||
if flags&TCPRst != 0 {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateFinWait1:
|
case TCPStateFinWait1:
|
||||||
return flags&TCPFin != 0 || flags&TCPAck != 0
|
return flags&TCPFin != 0 || flags&TCPAck != 0
|
||||||
@@ -311,9 +394,7 @@ func (t *TCPTracker) isValidStateForFlags(state TCPState, flags uint8) bool {
|
|||||||
case TCPStateLastAck:
|
case TCPStateLastAck:
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
case TCPStateClosed:
|
case TCPStateClosed:
|
||||||
// Accept retransmitted ACKs in closed state
|
// Accept retransmitted ACKs in closed state, the final ACK might be lost and the peer will retransmit their FIN-ACK
|
||||||
// This is important because the final ACK might be lost
|
|
||||||
// and the peer will retransmit their FIN-ACK
|
|
||||||
return flags&TCPAck != 0
|
return flags&TCPAck != 0
|
||||||
}
|
}
|
||||||
return false
|
return false
|
||||||
@@ -337,24 +418,33 @@ func (t *TCPTracker) cleanup() {
|
|||||||
defer t.mutex.Unlock()
|
defer t.mutex.Unlock()
|
||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
|
if conn.IsTombstone() {
|
||||||
|
// Clean up tombstoned connections without sending an event
|
||||||
|
delete(t.connections, key)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
|
||||||
var timeout time.Duration
|
var timeout time.Duration
|
||||||
switch {
|
currentState := conn.GetState()
|
||||||
case conn.State == TCPStateTimeWait:
|
switch currentState {
|
||||||
timeout = TimeWaitTimeout
|
case TCPStateTimeWait:
|
||||||
case conn.IsEstablished():
|
timeout = t.waitTimeout
|
||||||
|
case TCPStateEstablished:
|
||||||
timeout = t.timeout
|
timeout = t.timeout
|
||||||
default:
|
default:
|
||||||
timeout = TCPHandshakeTimeout
|
timeout = TCPHandshakeTimeout
|
||||||
}
|
}
|
||||||
|
|
||||||
lastSeen := conn.GetLastSeen()
|
if conn.timeoutExceeded(timeout) {
|
||||||
if time.Since(lastSeen) > timeout {
|
|
||||||
// Return IPs to pool
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Cleaned up TCP connection: %s:%d -> %s:%d", conn.SourceIP, conn.SourcePort, conn.DestIP, conn.DestPort)
|
t.logger.Trace("Cleaned up timed-out TCP connection %s (%s) [in: %d Pkts/%d, B out: %d Pkts/%d B]",
|
||||||
|
key, conn.GetState(), conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
|
||||||
|
// event already handled by state change
|
||||||
|
if currentState != TCPStateTimeWait {
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -365,10 +455,6 @@ func (t *TCPTracker) Close() {
|
|||||||
|
|
||||||
// Clean up all remaining IPs
|
// Clean up all remaining IPs
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
@@ -386,3 +472,21 @@ func isValidFlagCombination(flags uint8) bool {
|
|||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *TCPTracker) sendEvent(typ nftypes.Type, conn *TCPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
83
client/firewall/uspfilter/conntrack/tcp_bench_test.go
Normal file
@@ -0,0 +1,83 @@
|
|||||||
|
package conntrack
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func BenchmarkTCPTracker(b *testing.B) {
|
||||||
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
// Pre-populate some connections
|
||||||
|
for i := 0; i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck|TCPSyn, 0)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConcurrentAccess", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
i := 0
|
||||||
|
for pb.Next() {
|
||||||
|
if i%2 == 0 {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn, 0)
|
||||||
|
} else {
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck|TCPSyn, 0)
|
||||||
|
}
|
||||||
|
i++
|
||||||
|
}
|
||||||
|
})
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// Benchmark connection cleanup
|
||||||
|
func BenchmarkCleanup(b *testing.B) {
|
||||||
|
b.Run("TCPCleanup", func(b *testing.B) {
|
||||||
|
tracker := NewTCPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
// Pre-populate with expired connections
|
||||||
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
for i := 0; i < 10000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Wait for connections to expire
|
||||||
|
time.Sleep(200 * time.Millisecond)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
tracker.cleanup()
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
@@ -1,19 +1,20 @@
|
|||||||
package conntrack
|
package conntrack
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/assert"
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
)
|
)
|
||||||
|
|
||||||
func TestTCPStateMachine(t *testing.T) {
|
func TestTCPStateMachine(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -58,7 +59,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags)
|
isValid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, tt.flags, 0)
|
||||||
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
require.Equal(t, !tt.wantDrop, isValid, tt.desc)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -76,17 +77,17 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
// Send initial SYN
|
// Send initial SYN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
// Receive SYN-ACK
|
// Receive SYN-ACK
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
// Send ACK
|
// Send ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
// Test data transfer
|
// Test data transfer
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 0)
|
||||||
require.True(t, valid, "Data should be allowed after handshake")
|
require.True(t, valid, "Data should be allowed after handshake")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -99,18 +100,18 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Send FIN
|
// Send FIN
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
|
||||||
// Receive ACK for FIN
|
// Receive ACK for FIN
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "ACK for FIN should be allowed")
|
require.True(t, valid, "ACK for FIN should be allowed")
|
||||||
|
|
||||||
// Receive FIN from other side
|
// Receive FIN from other side
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "FIN should be allowed")
|
require.True(t, valid, "FIN should be allowed")
|
||||||
|
|
||||||
// Send final ACK
|
// Send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -122,11 +123,8 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Receive RST
|
// Receive RST
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
require.True(t, valid, "RST should be allowed for established connection")
|
require.True(t, valid, "RST should be allowed for established connection")
|
||||||
|
|
||||||
// Connection is logically dead but we don't enforce blocking subsequent packets
|
|
||||||
// The connection will be cleaned up by timeout
|
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -138,13 +136,13 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
// Both sides send FIN+ACK
|
// Both sides send FIN+ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
require.True(t, valid, "Simultaneous FIN should be allowed")
|
require.True(t, valid, "Simultaneous FIN should be allowed")
|
||||||
|
|
||||||
// Both sides send final ACK
|
// Both sides send final ACK
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck)
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
require.True(t, valid, "Final ACKs should be allowed")
|
require.True(t, valid, "Final ACKs should be allowed")
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
@@ -154,7 +152,7 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker = NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
tt.test(t)
|
tt.test(t)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -162,11 +160,11 @@ func TestTCPStateMachine(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestRSTHandling(t *testing.T) {
|
func TestRSTHandling(t *testing.T) {
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("100.64.0.1")
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
dstIP := net.ParseIP("100.64.0.2")
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(80)
|
dstPort := uint16(80)
|
||||||
|
|
||||||
@@ -181,12 +179,12 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST in established",
|
name: "RST in established",
|
||||||
setupState: func() {
|
setupState: func() {
|
||||||
// Establish connection first
|
// Establish connection first
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
},
|
},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: true,
|
wantValid: true,
|
||||||
desc: "Should accept RST for established connection",
|
desc: "Should accept RST for established connection",
|
||||||
@@ -195,7 +193,7 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
name: "RST without connection",
|
name: "RST without connection",
|
||||||
setupState: func() {},
|
setupState: func() {},
|
||||||
sendRST: func() {
|
sendRST: func() {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst)
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
},
|
},
|
||||||
wantValid: false,
|
wantValid: false,
|
||||||
desc: "Should reject RST without connection",
|
desc: "Should reject RST without connection",
|
||||||
@@ -208,101 +206,455 @@ func TestRSTHandling(t *testing.T) {
|
|||||||
tt.sendRST()
|
tt.sendRST()
|
||||||
|
|
||||||
// Verify connection state is as expected
|
// Verify connection state is as expected
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn := tracker.connections[key]
|
conn := tracker.connections[key]
|
||||||
if tt.wantValid {
|
if tt.wantValid {
|
||||||
require.NotNil(t, conn)
|
require.NotNil(t, conn)
|
||||||
require.Equal(t, TCPStateClosed, conn.State)
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
require.False(t, conn.IsEstablished())
|
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestTCPRetransmissions(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test SYN retransmission
|
||||||
|
t.Run("SYN Retransmission", func(t *testing.T) {
|
||||||
|
// Initial SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Retransmit SYN (should not affect the state machine)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Verify we're still in SYN-SENT state
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateSynSent, conn.GetState())
|
||||||
|
|
||||||
|
// Complete the handshake
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// Verify we're in ESTABLISHED state
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test ACK retransmission in established state
|
||||||
|
t.Run("ACK Retransmission", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Retransmit ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// State should remain ESTABLISHED
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test FIN retransmission
|
||||||
|
t.Run("FIN Retransmission", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Retransmit FIN (should not change state)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPDataTransfer(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Data Transfer", func(t *testing.T) {
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||||
|
|
||||||
|
// Receive ACK for data
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 100)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Receive data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1500)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Send ACK for received data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
|
|
||||||
|
// State should remain ESTABLISHED
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
assert.Equal(t, uint64(1300), conn.BytesTx.Load())
|
||||||
|
assert.Equal(t, uint64(1700), conn.BytesRx.Load())
|
||||||
|
assert.Equal(t, uint64(4), conn.PacketsTx.Load())
|
||||||
|
assert.Equal(t, uint64(3), conn.PacketsRx.Load())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPHalfClosedConnections(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test half-closed connection: local end closes, remote end continues sending data
|
||||||
|
t.Run("Local Close, Remote Data", func(t *testing.T) {
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Send FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateFinWait1, conn.GetState())
|
||||||
|
|
||||||
|
// Receive ACK for FIN
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateFinWait2, conn.GetState())
|
||||||
|
|
||||||
|
// Remote end can still send data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPPush|TCPAck, 1000)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// We can still ACK their data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// Receive FIN from remote end
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// Send final ACK
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
// State should remain TIME-WAIT (waiting for possible retransmissions)
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
})
|
||||||
|
|
||||||
|
// Test half-closed connection: remote end closes, local end continues sending data
|
||||||
|
t.Run("Remote Close, Local Data", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
// Establish connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Receive FIN from remote
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateCloseWait, conn.GetState())
|
||||||
|
|
||||||
|
// We can still send data
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPPush|TCPAck, 1000)
|
||||||
|
|
||||||
|
// Remote can still ACK our data
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
|
||||||
|
// Send our FIN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
require.Equal(t, TCPStateLastAck, conn.GetState())
|
||||||
|
|
||||||
|
// Receive final ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
require.True(t, valid)
|
||||||
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPAbnormalSequences(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Test handling of unsolicited RST in various states
|
||||||
|
t.Run("Unsolicited RST in SYN-SENT", func(t *testing.T) {
|
||||||
|
// Send SYN
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 0)
|
||||||
|
|
||||||
|
// Receive unsolicited RST (without proper ACK)
|
||||||
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst, 0)
|
||||||
|
require.False(t, valid, "RST without proper ACK in SYN-SENT should be rejected")
|
||||||
|
|
||||||
|
// Receive RST with proper ACK
|
||||||
|
valid = tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPRst|TCPAck, 0)
|
||||||
|
require.True(t, valid, "RST with proper ACK in SYN-SENT should be accepted")
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.Equal(t, TCPStateClosed, conn.GetState())
|
||||||
|
require.True(t, conn.IsTombstone())
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPTimeoutHandling(t *testing.T) {
|
||||||
|
// Create tracker with a very short timeout for testing
|
||||||
|
shortTimeout := 100 * time.Millisecond
|
||||||
|
tracker := NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
t.Run("Connection Timeout", func(t *testing.T) {
|
||||||
|
// Establish a connection
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
// Get connection object
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
|
||||||
|
// Wait for the connection to timeout
|
||||||
|
time.Sleep(2 * shortTimeout)
|
||||||
|
|
||||||
|
// Force cleanup
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Connection should be removed
|
||||||
|
_, exists := tracker.connections[key]
|
||||||
|
require.False(t, exists, "Connection should be removed after timeout")
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("TIME_WAIT Timeout", func(t *testing.T) {
|
||||||
|
tracker = NewTCPTracker(shortTimeout, logger, flowLogger)
|
||||||
|
|
||||||
|
establishConnection(t, tracker, srcIP, dstIP, srcPort, dstPort)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
|
||||||
|
// Complete the connection close to enter TIME_WAIT
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPAck, 0)
|
||||||
|
tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPFin|TCPAck, 0)
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 0)
|
||||||
|
|
||||||
|
require.Equal(t, TCPStateTimeWait, conn.GetState())
|
||||||
|
|
||||||
|
// TIME_WAIT should have its own timeout value (usually 2*MSL)
|
||||||
|
// For the test, we're using a short timeout
|
||||||
|
time.Sleep(2 * shortTimeout)
|
||||||
|
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Connection should be removed
|
||||||
|
_, exists := tracker.connections[key]
|
||||||
|
require.False(t, exists, "Connection should be removed after TIME_WAIT timeout")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestSynFlood(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
srcIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
dstIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
basePort := uint16(10000)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
// Create a large number of SYN packets to simulate a SYN flood
|
||||||
|
for i := uint16(0); i < 1000; i++ {
|
||||||
|
tracker.TrackOutbound(srcIP, dstIP, basePort+i, dstPort, TCPSyn, 0)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check that we're tracking all connections
|
||||||
|
require.Equal(t, 1000, len(tracker.connections))
|
||||||
|
|
||||||
|
// Now simulate SYN timeout
|
||||||
|
var oldConns int
|
||||||
|
tracker.mutex.Lock()
|
||||||
|
for _, conn := range tracker.connections {
|
||||||
|
if conn.GetState() == TCPStateSynSent {
|
||||||
|
// Make the connection appear old
|
||||||
|
conn.lastSeen.Store(time.Now().Add(-TCPHandshakeTimeout - time.Second).UnixNano())
|
||||||
|
oldConns++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
tracker.mutex.Unlock()
|
||||||
|
require.Equal(t, 1000, oldConns)
|
||||||
|
|
||||||
|
// Run cleanup
|
||||||
|
tracker.cleanup()
|
||||||
|
|
||||||
|
// Check that stale connections were cleaned up
|
||||||
|
require.Equal(t, 0, len(tracker.connections))
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTCPInboundInitiatedConnection(t *testing.T) {
|
||||||
|
tracker := NewTCPTracker(DefaultTCPTimeout, logger, flowLogger)
|
||||||
|
defer tracker.Close()
|
||||||
|
|
||||||
|
clientIP := netip.MustParseAddr("100.64.0.1")
|
||||||
|
serverIP := netip.MustParseAddr("100.64.0.2")
|
||||||
|
clientPort := uint16(12345)
|
||||||
|
serverPort := uint16(80)
|
||||||
|
|
||||||
|
// 1. Client sends SYN (we receive it as inbound)
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPSyn, nil, 100)
|
||||||
|
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: clientIP,
|
||||||
|
DstIP: serverIP,
|
||||||
|
SrcPort: clientPort,
|
||||||
|
DstPort: serverPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
tracker.mutex.RLock()
|
||||||
|
conn := tracker.connections[key]
|
||||||
|
tracker.mutex.RUnlock()
|
||||||
|
|
||||||
|
require.NotNil(t, conn)
|
||||||
|
require.Equal(t, TCPStateSynReceived, conn.GetState(), "Connection should be in SYN-RECEIVED state after inbound SYN")
|
||||||
|
|
||||||
|
// 2. Server sends SYN-ACK response
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPSyn|TCPAck, 100)
|
||||||
|
|
||||||
|
// 3. Client sends ACK to complete handshake
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState(), "Connection should be ESTABLISHED after handshake completion")
|
||||||
|
|
||||||
|
// 4. Test data transfer
|
||||||
|
// Client sends data
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPPush|TCPAck, nil, 1000)
|
||||||
|
|
||||||
|
// Server sends ACK for data
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPAck, 100)
|
||||||
|
|
||||||
|
// Server sends data
|
||||||
|
tracker.TrackOutbound(serverIP, clientIP, serverPort, clientPort, TCPPush|TCPAck, 1500)
|
||||||
|
|
||||||
|
// Client sends ACK for data
|
||||||
|
tracker.TrackInbound(clientIP, serverIP, clientPort, serverPort, TCPAck, nil, 100)
|
||||||
|
|
||||||
|
// Verify state and counters
|
||||||
|
require.Equal(t, TCPStateEstablished, conn.GetState())
|
||||||
|
assert.Equal(t, uint64(1300), conn.BytesRx.Load()) // 3 packets * 100 + 1000 data
|
||||||
|
assert.Equal(t, uint64(1700), conn.BytesTx.Load()) // 2 packets * 100 + 1500 data
|
||||||
|
assert.Equal(t, uint64(4), conn.PacketsRx.Load()) // SYN, ACK, Data
|
||||||
|
assert.Equal(t, uint64(3), conn.PacketsTx.Load()) // SYN-ACK, Data
|
||||||
|
}
|
||||||
|
|
||||||
// Helper to establish a TCP connection
|
// Helper to establish a TCP connection
|
||||||
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP net.IP, srcPort, dstPort uint16) {
|
func establishConnection(t *testing.T, tracker *TCPTracker, srcIP, dstIP netip.Addr, srcPort, dstPort uint16) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPSyn, 100)
|
||||||
|
|
||||||
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck)
|
valid := tracker.IsValidInbound(dstIP, srcIP, dstPort, srcPort, TCPSyn|TCPAck, 100)
|
||||||
require.True(t, valid, "SYN-ACK should be allowed")
|
require.True(t, valid, "SYN-ACK should be allowed")
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, TCPAck, 100)
|
||||||
}
|
|
||||||
|
|
||||||
func BenchmarkTCPTracker(b *testing.B) {
|
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
|
|
||||||
// Pre-populate some connections
|
|
||||||
for i := 0; i < 1000; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
|
||||||
}
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), TCPAck)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
|
|
||||||
b.Run("ConcurrentAccess", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(DefaultTCPTimeout, logger)
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
|
|
||||||
b.RunParallel(func(pb *testing.PB) {
|
|
||||||
i := 0
|
|
||||||
for pb.Next() {
|
|
||||||
if i%2 == 0 {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, TCPSyn)
|
|
||||||
} else {
|
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%65535), TCPAck)
|
|
||||||
}
|
|
||||||
i++
|
|
||||||
}
|
|
||||||
})
|
|
||||||
})
|
|
||||||
}
|
|
||||||
|
|
||||||
// Benchmark connection cleanup
|
|
||||||
func BenchmarkCleanup(b *testing.B) {
|
|
||||||
b.Run("TCPCleanup", func(b *testing.B) {
|
|
||||||
tracker := NewTCPTracker(100*time.Millisecond, logger) // Short timeout for testing
|
|
||||||
defer tracker.Close()
|
|
||||||
|
|
||||||
// Pre-populate with expired connections
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
|
||||||
for i := 0; i < 10000; i++ {
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, TCPSyn)
|
|
||||||
}
|
|
||||||
|
|
||||||
// Wait for connections to expire
|
|
||||||
time.Sleep(200 * time.Millisecond)
|
|
||||||
|
|
||||||
b.ResetTimer()
|
|
||||||
for i := 0; i < b.N; i++ {
|
|
||||||
tracker.cleanup()
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,11 +2,14 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -19,6 +22,8 @@ const (
|
|||||||
// UDPConnTrack represents a UDP connection state
|
// UDPConnTrack represents a UDP connection state
|
||||||
type UDPConnTrack struct {
|
type UDPConnTrack struct {
|
||||||
BaseConnTrack
|
BaseConnTrack
|
||||||
|
SourcePort uint16
|
||||||
|
DestPort uint16
|
||||||
}
|
}
|
||||||
|
|
||||||
// UDPTracker manages UDP connection states
|
// UDPTracker manages UDP connection states
|
||||||
@@ -29,11 +34,11 @@ type UDPTracker struct {
|
|||||||
cleanupTicker *time.Ticker
|
cleanupTicker *time.Ticker
|
||||||
tickerCancel context.CancelFunc
|
tickerCancel context.CancelFunc
|
||||||
mutex sync.RWMutex
|
mutex sync.RWMutex
|
||||||
ipPool *PreallocatedIPs
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewUDPTracker creates a new UDP connection tracker
|
// NewUDPTracker creates a new UDP connection tracker
|
||||||
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
func NewUDPTracker(timeout time.Duration, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *UDPTracker {
|
||||||
if timeout == 0 {
|
if timeout == 0 {
|
||||||
timeout = DefaultUDPTimeout
|
timeout = DefaultUDPTimeout
|
||||||
}
|
}
|
||||||
@@ -46,7 +51,7 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
cleanupTicker: time.NewTicker(UDPCleanupInterval),
|
||||||
tickerCancel: cancel,
|
tickerCancel: cancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
go tracker.cleanupRoutine(ctx)
|
go tracker.cleanupRoutine(ctx)
|
||||||
@@ -54,55 +59,88 @@ func NewUDPTracker(timeout time.Duration, logger *nblog.Logger) *UDPTracker {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// TrackOutbound records an outbound UDP connection
|
// TrackOutbound records an outbound UDP connection
|
||||||
func (t *UDPTracker) TrackOutbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) {
|
func (t *UDPTracker) TrackOutbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) {
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
if _, exists := t.updateIfExists(dstIP, srcIP, dstPort, srcPort, nftypes.Egress, size); !exists {
|
||||||
|
// if (inverted direction) conn is not tracked, track this direction
|
||||||
t.mutex.Lock()
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Egress, nil, size)
|
||||||
conn, exists := t.connections[key]
|
|
||||||
if !exists {
|
|
||||||
srcIPCopy := t.ipPool.Get()
|
|
||||||
dstIPCopy := t.ipPool.Get()
|
|
||||||
copyIP(srcIPCopy, srcIP)
|
|
||||||
copyIP(dstIPCopy, dstIP)
|
|
||||||
|
|
||||||
conn = &UDPConnTrack{
|
|
||||||
BaseConnTrack: BaseConnTrack{
|
|
||||||
SourceIP: srcIPCopy,
|
|
||||||
DestIP: dstIPCopy,
|
|
||||||
SourcePort: srcPort,
|
|
||||||
DestPort: dstPort,
|
|
||||||
},
|
|
||||||
}
|
}
|
||||||
conn.UpdateLastSeen()
|
|
||||||
t.connections[key] = conn
|
|
||||||
|
|
||||||
t.logger.Trace("New UDP connection: %v", conn)
|
|
||||||
}
|
|
||||||
t.mutex.Unlock()
|
|
||||||
|
|
||||||
conn.UpdateLastSeen()
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// IsValidInbound checks if an inbound packet matches a tracked connection
|
// TrackInbound records an inbound UDP connection
|
||||||
func (t *UDPTracker) IsValidInbound(srcIP net.IP, dstIP net.IP, srcPort uint16, dstPort uint16) bool {
|
func (t *UDPTracker) TrackInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, ruleID []byte, size int) {
|
||||||
key := makeConnKey(dstIP, srcIP, dstPort, srcPort)
|
t.track(srcIP, dstIP, srcPort, dstPort, nftypes.Ingress, ruleID, size)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) updateIfExists(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, size int) (ConnKey, bool) {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
|
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
conn, exists := t.connections[key]
|
conn, exists := t.connections[key]
|
||||||
t.mutex.RUnlock()
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
if !exists {
|
if exists {
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
return key, true
|
||||||
|
}
|
||||||
|
|
||||||
|
return key, false
|
||||||
|
}
|
||||||
|
|
||||||
|
// track is the common implementation for tracking both inbound and outbound connections
|
||||||
|
func (t *UDPTracker) track(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, direction nftypes.Direction, ruleID []byte, size int) {
|
||||||
|
key, exists := t.updateIfExists(srcIP, dstIP, srcPort, dstPort, direction, size)
|
||||||
|
if exists {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
conn := &UDPConnTrack{
|
||||||
|
BaseConnTrack: BaseConnTrack{
|
||||||
|
FlowId: uuid.New(),
|
||||||
|
Direction: direction,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
},
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
}
|
||||||
|
conn.UpdateLastSeen()
|
||||||
|
conn.UpdateCounters(direction, size)
|
||||||
|
|
||||||
|
t.mutex.Lock()
|
||||||
|
t.connections[key] = conn
|
||||||
|
t.mutex.Unlock()
|
||||||
|
|
||||||
|
t.logger.Trace("New %s UDP connection: %s", direction, key)
|
||||||
|
t.sendEvent(nftypes.TypeStart, conn, ruleID)
|
||||||
|
}
|
||||||
|
|
||||||
|
// IsValidInbound checks if an inbound packet matches a tracked connection
|
||||||
|
func (t *UDPTracker) IsValidInbound(srcIP netip.Addr, dstIP netip.Addr, srcPort uint16, dstPort uint16, size int) bool {
|
||||||
|
key := ConnKey{
|
||||||
|
SrcIP: dstIP,
|
||||||
|
DstIP: srcIP,
|
||||||
|
SrcPort: dstPort,
|
||||||
|
DstPort: srcPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
t.mutex.RLock()
|
||||||
|
conn, exists := t.connections[key]
|
||||||
|
t.mutex.RUnlock()
|
||||||
|
|
||||||
|
if !exists || conn.timeoutExceeded(t.timeout) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
conn.UpdateLastSeen()
|
||||||
return false
|
conn.UpdateCounters(nftypes.Ingress, size)
|
||||||
}
|
|
||||||
|
|
||||||
return ValidateIPs(MakeIPAddr(srcIP), conn.DestIP) &&
|
return true
|
||||||
ValidateIPs(MakeIPAddr(dstIP), conn.SourceIP) &&
|
|
||||||
conn.DestPort == srcPort &&
|
|
||||||
conn.SourcePort == dstPort
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// cleanupRoutine periodically removes stale connections
|
// cleanupRoutine periodically removes stale connections
|
||||||
@@ -125,11 +163,11 @@ func (t *UDPTracker) cleanup() {
|
|||||||
|
|
||||||
for key, conn := range t.connections {
|
for key, conn := range t.connections {
|
||||||
if conn.timeoutExceeded(t.timeout) {
|
if conn.timeoutExceeded(t.timeout) {
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
delete(t.connections, key)
|
delete(t.connections, key)
|
||||||
|
|
||||||
t.logger.Trace("Removed UDP connection %v (timeout)", conn)
|
t.logger.Trace("Removed UDP connection %s (timeout) [in: %d Pkts/%d B, out: %d Pkts/%d B]",
|
||||||
|
key, conn.PacketsRx.Load(), conn.BytesRx.Load(), conn.PacketsTx.Load(), conn.BytesTx.Load())
|
||||||
|
t.sendEvent(nftypes.TypeEnd, conn, nil)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -139,29 +177,44 @@ func (t *UDPTracker) Close() {
|
|||||||
t.tickerCancel()
|
t.tickerCancel()
|
||||||
|
|
||||||
t.mutex.Lock()
|
t.mutex.Lock()
|
||||||
for _, conn := range t.connections {
|
|
||||||
t.ipPool.Put(conn.SourceIP)
|
|
||||||
t.ipPool.Put(conn.DestIP)
|
|
||||||
}
|
|
||||||
t.connections = nil
|
t.connections = nil
|
||||||
t.mutex.Unlock()
|
t.mutex.Unlock()
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetConnection safely retrieves a connection state
|
// GetConnection safely retrieves a connection state
|
||||||
func (t *UDPTracker) GetConnection(srcIP net.IP, srcPort uint16, dstIP net.IP, dstPort uint16) (*UDPConnTrack, bool) {
|
func (t *UDPTracker) GetConnection(srcIP netip.Addr, srcPort uint16, dstIP netip.Addr, dstPort uint16) (*UDPConnTrack, bool) {
|
||||||
t.mutex.RLock()
|
t.mutex.RLock()
|
||||||
defer t.mutex.RUnlock()
|
defer t.mutex.RUnlock()
|
||||||
|
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
conn, exists := t.connections[key]
|
SrcIP: srcIP,
|
||||||
if !exists {
|
DstIP: dstIP,
|
||||||
return nil, false
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
}
|
}
|
||||||
|
conn, exists := t.connections[key]
|
||||||
return conn, true
|
return conn, exists
|
||||||
}
|
}
|
||||||
|
|
||||||
// Timeout returns the configured timeout duration for the tracker
|
// Timeout returns the configured timeout duration for the tracker
|
||||||
func (t *UDPTracker) Timeout() time.Duration {
|
func (t *UDPTracker) Timeout() time.Duration {
|
||||||
return t.timeout
|
return t.timeout
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (t *UDPTracker) sendEvent(typ nftypes.Type, conn *UDPConnTrack, ruleID []byte) {
|
||||||
|
t.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: conn.FlowId,
|
||||||
|
Type: typ,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: conn.Direction,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
SourceIP: conn.SourceIP,
|
||||||
|
DestIP: conn.DestIP,
|
||||||
|
SourcePort: conn.SourcePort,
|
||||||
|
DestPort: conn.DestPort,
|
||||||
|
RxPackets: conn.PacketsRx.Load(),
|
||||||
|
TxPackets: conn.PacketsTx.Load(),
|
||||||
|
RxBytes: conn.BytesRx.Load(),
|
||||||
|
TxBytes: conn.BytesTx.Load(),
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package conntrack
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
@@ -30,7 +30,7 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
|
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
tracker := NewUDPTracker(tt.timeout, logger)
|
tracker := NewUDPTracker(tt.timeout, logger, flowLogger)
|
||||||
assert.NotNil(t, tracker)
|
assert.NotNil(t, tracker)
|
||||||
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
assert.Equal(t, tt.wantTimeout, tracker.timeout)
|
||||||
assert.NotNil(t, tracker.connections)
|
assert.NotNil(t, tracker.connections)
|
||||||
@@ -41,43 +41,48 @@ func TestNewUDPTracker(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
func TestUDPTracker_TrackOutbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
key := makeConnKey(srcIP, dstIP, srcPort, dstPort)
|
key := ConnKey{
|
||||||
|
SrcIP: srcIP,
|
||||||
|
DstIP: dstIP,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
}
|
||||||
conn, exists := tracker.connections[key]
|
conn, exists := tracker.connections[key]
|
||||||
require.True(t, exists)
|
require.True(t, exists)
|
||||||
assert.True(t, conn.SourceIP.Equal(srcIP))
|
assert.True(t, conn.SourceIP.Compare(srcIP) == 0)
|
||||||
assert.True(t, conn.DestIP.Equal(dstIP))
|
assert.True(t, conn.DestIP.Compare(dstIP) == 0)
|
||||||
assert.Equal(t, srcPort, conn.SourcePort)
|
assert.Equal(t, srcPort, conn.SourcePort)
|
||||||
assert.Equal(t, dstPort, conn.DestPort)
|
assert.Equal(t, dstPort, conn.DestPort)
|
||||||
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
assert.WithinDuration(t, time.Now(), conn.GetLastSeen(), 1*time.Second)
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
func TestUDPTracker_IsValidInbound(t *testing.T) {
|
||||||
tracker := NewUDPTracker(1*time.Second, logger)
|
tracker := NewUDPTracker(1*time.Second, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.2")
|
srcIP := netip.MustParseAddr("192.168.1.2")
|
||||||
dstIP := net.ParseIP("192.168.1.3")
|
dstIP := netip.MustParseAddr("192.168.1.3")
|
||||||
srcPort := uint16(12345)
|
srcPort := uint16(12345)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
// Track outbound connection
|
// Track outbound connection
|
||||||
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort)
|
tracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, 0)
|
||||||
|
|
||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
sleep time.Duration
|
sleep time.Duration
|
||||||
@@ -94,7 +99,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
},
|
},
|
||||||
{
|
{
|
||||||
name: "invalid source IP",
|
name: "invalid source IP",
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: srcIP,
|
dstIP: srcIP,
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
@@ -104,7 +109,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
{
|
{
|
||||||
name: "invalid destination IP",
|
name: "invalid destination IP",
|
||||||
srcIP: dstIP,
|
srcIP: dstIP,
|
||||||
dstIP: net.ParseIP("192.168.1.4"),
|
dstIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
srcPort: dstPort,
|
srcPort: dstPort,
|
||||||
dstPort: srcPort,
|
dstPort: srcPort,
|
||||||
sleep: 0,
|
sleep: 0,
|
||||||
@@ -144,7 +149,7 @@ func TestUDPTracker_IsValidInbound(t *testing.T) {
|
|||||||
if tt.sleep > 0 {
|
if tt.sleep > 0 {
|
||||||
time.Sleep(tt.sleep)
|
time.Sleep(tt.sleep)
|
||||||
}
|
}
|
||||||
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort)
|
got := tracker.IsValidInbound(tt.srcIP, tt.dstIP, tt.srcPort, tt.dstPort, 0)
|
||||||
assert.Equal(t, tt.want, got)
|
assert.Equal(t, tt.want, got)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -164,8 +169,8 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
timeout: timeout,
|
timeout: timeout,
|
||||||
cleanupTicker: time.NewTicker(cleanupInterval),
|
cleanupTicker: time.NewTicker(cleanupInterval),
|
||||||
tickerCancel: tickerCancel,
|
tickerCancel: tickerCancel,
|
||||||
ipPool: NewPreallocatedIPs(),
|
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
}
|
}
|
||||||
|
|
||||||
// Start cleanup routine
|
// Start cleanup routine
|
||||||
@@ -173,27 +178,27 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
// Add some connections
|
// Add some connections
|
||||||
connections := []struct {
|
connections := []struct {
|
||||||
srcIP net.IP
|
srcIP netip.Addr
|
||||||
dstIP net.IP
|
dstIP netip.Addr
|
||||||
srcPort uint16
|
srcPort uint16
|
||||||
dstPort uint16
|
dstPort uint16
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.2"),
|
srcIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
dstIP: net.ParseIP("192.168.1.3"),
|
dstIP: netip.MustParseAddr("192.168.1.3"),
|
||||||
srcPort: 12345,
|
srcPort: 12345,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
srcIP: net.ParseIP("192.168.1.4"),
|
srcIP: netip.MustParseAddr("192.168.1.4"),
|
||||||
dstIP: net.ParseIP("192.168.1.5"),
|
dstIP: netip.MustParseAddr("192.168.1.5"),
|
||||||
srcPort: 12346,
|
srcPort: 12346,
|
||||||
dstPort: 53,
|
dstPort: 53,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, conn := range connections {
|
for _, conn := range connections {
|
||||||
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort)
|
tracker.TrackOutbound(conn.srcIP, conn.dstIP, conn.srcPort, conn.dstPort, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify initial connections
|
// Verify initial connections
|
||||||
@@ -215,33 +220,33 @@ func TestUDPTracker_Cleanup(t *testing.T) {
|
|||||||
|
|
||||||
func BenchmarkUDPTracker(b *testing.B) {
|
func BenchmarkUDPTracker(b *testing.B) {
|
||||||
b.Run("TrackOutbound", func(b *testing.B) {
|
b.Run("TrackOutbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i%65535), 80, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|
||||||
b.Run("IsValidInbound", func(b *testing.B) {
|
b.Run("IsValidInbound", func(b *testing.B) {
|
||||||
tracker := NewUDPTracker(DefaultUDPTimeout, logger)
|
tracker := NewUDPTracker(DefaultUDPTimeout, logger, flowLogger)
|
||||||
defer tracker.Close()
|
defer tracker.Close()
|
||||||
|
|
||||||
srcIP := net.ParseIP("192.168.1.1")
|
srcIP := netip.MustParseAddr("192.168.1.1")
|
||||||
dstIP := net.ParseIP("192.168.1.2")
|
dstIP := netip.MustParseAddr("192.168.1.2")
|
||||||
|
|
||||||
// Pre-populate some connections
|
// Pre-populate some connections
|
||||||
for i := 0; i < 1000; i++ {
|
for i := 0; i < 1000; i++ {
|
||||||
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80)
|
tracker.TrackOutbound(srcIP, dstIP, uint16(i), 80, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000))
|
tracker.IsValidInbound(dstIP, srcIP, 80, uint16(i%1000), 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -1,6 +1,8 @@
|
|||||||
package forwarder
|
package forwarder
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"fmt"
|
||||||
|
|
||||||
wgdevice "golang.zx2c4.com/wireguard/device"
|
wgdevice "golang.zx2c4.com/wireguard/device"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
@@ -79,3 +81,10 @@ func (e *endpoint) AddHeader(*stack.PacketBuffer) {
|
|||||||
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
func (e *endpoint) ParseHeader(*stack.PacketBuffer) bool {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type epID stack.TransportEndpointID
|
||||||
|
|
||||||
|
func (i epID) String() string {
|
||||||
|
// src and remote is swapped
|
||||||
|
return fmt.Sprintf("%s:%d -> %s:%d", i.RemoteAddress, i.RemotePort, i.LocalAddress, i.LocalPort)
|
||||||
|
}
|
||||||
|
|||||||
@@ -18,6 +18,7 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/common"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -29,6 +30,7 @@ const (
|
|||||||
|
|
||||||
type Forwarder struct {
|
type Forwarder struct {
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
stack *stack.Stack
|
stack *stack.Stack
|
||||||
endpoint *endpoint
|
endpoint *endpoint
|
||||||
udpForwarder *udpForwarder
|
udpForwarder *udpForwarder
|
||||||
@@ -38,7 +40,7 @@ type Forwarder struct {
|
|||||||
netstack bool
|
netstack bool
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwarder, error) {
|
func New(iface common.IFaceMapper, logger *nblog.Logger, flowLogger nftypes.FlowLogger, netstack bool) (*Forwarder, error) {
|
||||||
s := stack.New(stack.Options{
|
s := stack.New(stack.Options{
|
||||||
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
NetworkProtocols: []stack.NetworkProtocolFactory{ipv4.NewProtocol},
|
||||||
TransportProtocols: []stack.TransportProtocolFactory{
|
TransportProtocols: []stack.TransportProtocolFactory{
|
||||||
@@ -102,9 +104,10 @@ func New(iface common.IFaceMapper, logger *nblog.Logger, netstack bool) (*Forwar
|
|||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &Forwarder{
|
f := &Forwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
stack: s,
|
stack: s,
|
||||||
endpoint: endpoint,
|
endpoint: endpoint,
|
||||||
udpForwarder: newUDPForwarder(mtu, logger),
|
udpForwarder: newUDPForwarder(mtu, logger, flowLogger),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
netstack: netstack,
|
netstack: netstack,
|
||||||
|
|||||||
@@ -3,14 +3,30 @@ package forwarder
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/header"
|
"gvisor.dev/gvisor/pkg/tcpip/header"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleICMP handles ICMP packets from the network stack
|
// handleICMP handles ICMP packets from the network stack
|
||||||
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBufferPtr) bool {
|
||||||
|
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
||||||
|
icmpType := uint8(icmpHdr.Type())
|
||||||
|
icmpCode := uint8(icmpHdr.Code())
|
||||||
|
|
||||||
|
if header.ICMPv4Type(icmpType) == header.ICMPv4EchoReply {
|
||||||
|
// dont process our own replies
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
f.sendICMPEvent(nftypes.TypeStart, flowID, id, icmpType, icmpCode)
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
ctx, cancel := context.WithTimeout(f.ctx, 5*time.Second)
|
||||||
defer cancel()
|
defer cancel()
|
||||||
|
|
||||||
@@ -18,7 +34,7 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
// TODO: support non-root
|
// TODO: support non-root
|
||||||
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
conn, err := lc.ListenPacket(ctx, "ip4:icmp", "0.0.0.0")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Error("Failed to create ICMP socket for %v: %v", id, err)
|
f.logger.Error("Failed to create ICMP socket for %v: %v", epID(id), err)
|
||||||
|
|
||||||
// This will make netstack reply on behalf of the original destination, that's ok for now
|
// This will make netstack reply on behalf of the original destination, that's ok for now
|
||||||
return false
|
return false
|
||||||
@@ -32,47 +48,31 @@ func (f *Forwarder) handleICMP(id stack.TransportEndpointID, pkt stack.PacketBuf
|
|||||||
dstIP := f.determineDialAddr(id.LocalAddress)
|
dstIP := f.determineDialAddr(id.LocalAddress)
|
||||||
dst := &net.IPAddr{IP: dstIP}
|
dst := &net.IPAddr{IP: dstIP}
|
||||||
|
|
||||||
// Get the complete ICMP message (header + data)
|
|
||||||
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
fullPacket := stack.PayloadSince(pkt.TransportHeader())
|
||||||
payload := fullPacket.AsSlice()
|
payload := fullPacket.AsSlice()
|
||||||
|
|
||||||
icmpHdr := header.ICMPv4(pkt.TransportHeader().View().AsSlice())
|
if _, err = conn.WriteTo(payload, dst); err != nil {
|
||||||
|
f.logger.Error("Failed to write ICMP packet for %v: %v", epID(id), err)
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
f.logger.Trace("Forwarded ICMP packet %v type %v code %v",
|
||||||
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
|
||||||
// For Echo Requests, send and handle response
|
// For Echo Requests, send and handle response
|
||||||
switch icmpHdr.Type() {
|
if header.ICMPv4Type(icmpType) == header.ICMPv4Echo {
|
||||||
case header.ICMPv4Echo:
|
f.handleEchoResponse(icmpHdr, conn, id)
|
||||||
return f.handleEchoResponse(icmpHdr, payload, dst, conn, id)
|
f.sendICMPEvent(nftypes.TypeEnd, flowID, id, icmpType, icmpCode)
|
||||||
case header.ICMPv4EchoReply:
|
|
||||||
// dont process our own replies
|
|
||||||
return true
|
|
||||||
default:
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// For other ICMP types (Time Exceeded, Destination Unreachable, etc)
|
// For other ICMP types (Time Exceeded, Destination Unreachable, etc) do nothing
|
||||||
_, err = conn.WriteTo(payload, dst)
|
|
||||||
if err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, dst *net.IPAddr, conn net.PacketConn, id stack.TransportEndpointID) bool {
|
func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, conn net.PacketConn, id stack.TransportEndpointID) {
|
||||||
if _, err := conn.WriteTo(payload, dst); err != nil {
|
|
||||||
f.logger.Error("Failed to write ICMP packet for %v: %v", id, err)
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP packet %v type=%v code=%v",
|
|
||||||
id, icmpHdr.Type(), icmpHdr.Code())
|
|
||||||
|
|
||||||
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
if err := conn.SetReadDeadline(time.Now().Add(5 * time.Second)); err != nil {
|
||||||
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
f.logger.Error("Failed to set read deadline for ICMP response: %v", err)
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
response := make([]byte, f.endpoint.mtu)
|
response := make([]byte, f.endpoint.mtu)
|
||||||
@@ -81,7 +81,7 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
if !isTimeout(err) {
|
if !isTimeout(err) {
|
||||||
f.logger.Error("Failed to read ICMP response: %v", err)
|
f.logger.Error("Failed to read ICMP response: %v", err)
|
||||||
}
|
}
|
||||||
return true
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipHdr := make([]byte, header.IPv4MinimumSize)
|
ipHdr := make([]byte, header.IPv4MinimumSize)
|
||||||
@@ -101,9 +101,27 @@ func (f *Forwarder) handleEchoResponse(icmpHdr header.ICMPv4, payload []byte, ds
|
|||||||
|
|
||||||
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
if err := f.InjectIncomingPacket(fullPacket); err != nil {
|
||||||
f.logger.Error("Failed to inject ICMP response: %v", err)
|
f.logger.Error("Failed to inject ICMP response: %v", err)
|
||||||
return true
|
|
||||||
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
f.logger.Trace("Forwarded ICMP echo reply for %v", id)
|
f.logger.Trace("Forwarded ICMP echo reply for %v type %v code %v",
|
||||||
return true
|
epID(id), icmpHdr.Type(), icmpHdr.Code())
|
||||||
|
}
|
||||||
|
|
||||||
|
// sendICMPEvent stores flow events for ICMP packets
|
||||||
|
func (f *Forwarder) sendICMPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, icmpType, icmpCode uint8) {
|
||||||
|
f.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.ICMP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
|
||||||
|
// TODO: get packets/bytes
|
||||||
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -5,24 +5,38 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"io"
|
"io"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
"gvisor.dev/gvisor/pkg/tcpip/transport/tcp"
|
||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
// handleTCP is called by the TCP forwarder for new connections.
|
// handleTCP is called by the TCP forwarder for new connections.
|
||||||
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
||||||
id := r.ID()
|
id := r.ID()
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dialAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
|
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "tcp", dialAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
r.Complete(true)
|
r.Complete(true)
|
||||||
f.logger.Trace("forwarder: dial error for %v: %v", id, err)
|
f.logger.Trace("forwarder: dial error for %v: %v", epID(id), err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -44,12 +58,13 @@ func (f *Forwarder) handleTCP(r *tcp.ForwarderRequest) {
|
|||||||
|
|
||||||
inConn := gonet.NewTCPConn(&wq, ep)
|
inConn := gonet.NewTCPConn(&wq, ep)
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established TCP connection %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established TCP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyTCP(id, inConn, outConn, ep)
|
go f.proxyTCP(id, inConn, outConn, ep, flowID)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint) {
|
func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn, outConn net.Conn, ep tcpip.Endpoint, flowID uuid.UUID) {
|
||||||
defer func() {
|
defer func() {
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: inConn close error: %v", err)
|
f.logger.Debug("forwarder: inConn close error: %v", err)
|
||||||
@@ -58,6 +73,8 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
f.logger.Debug("forwarder: outConn close error: %v", err)
|
f.logger.Debug("forwarder: outConn close error: %v", err)
|
||||||
}
|
}
|
||||||
ep.Close()
|
ep.Close()
|
||||||
|
|
||||||
|
f.sendTCPEvent(nftypes.TypeEnd, flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
// Create context for managing the proxy goroutines
|
// Create context for managing the proxy goroutines
|
||||||
@@ -78,13 +95,38 @@ func (f *Forwarder) proxyTCP(id stack.TransportEndpointID, inConn *gonet.TCPConn
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyTCP: copy error: %v", err)
|
f.logger.Error("proxyTCP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down TCP connection %v", id)
|
f.logger.Trace("forwarder: tearing down TCP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (f *Forwarder) sendTCPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.TCP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcp.Stats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.SegmentsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.SegmentsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,10 +5,12 @@ import (
|
|||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"sync/atomic"
|
"sync/atomic"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/google/uuid"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip"
|
"gvisor.dev/gvisor/pkg/tcpip"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
"gvisor.dev/gvisor/pkg/tcpip/adapters/gonet"
|
||||||
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
"gvisor.dev/gvisor/pkg/tcpip/stack"
|
||||||
@@ -16,6 +18,7 @@ import (
|
|||||||
"gvisor.dev/gvisor/pkg/waiter"
|
"gvisor.dev/gvisor/pkg/waiter"
|
||||||
|
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -28,11 +31,13 @@ type udpPacketConn struct {
|
|||||||
lastSeen atomic.Int64
|
lastSeen atomic.Int64
|
||||||
cancel context.CancelFunc
|
cancel context.CancelFunc
|
||||||
ep tcpip.Endpoint
|
ep tcpip.Endpoint
|
||||||
|
flowID uuid.UUID
|
||||||
}
|
}
|
||||||
|
|
||||||
type udpForwarder struct {
|
type udpForwarder struct {
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
conns map[stack.TransportEndpointID]*udpPacketConn
|
conns map[stack.TransportEndpointID]*udpPacketConn
|
||||||
bufPool sync.Pool
|
bufPool sync.Pool
|
||||||
ctx context.Context
|
ctx context.Context
|
||||||
@@ -44,10 +49,11 @@ type idleConn struct {
|
|||||||
conn *udpPacketConn
|
conn *udpPacketConn
|
||||||
}
|
}
|
||||||
|
|
||||||
func newUDPForwarder(mtu int, logger *nblog.Logger) *udpForwarder {
|
func newUDPForwarder(mtu int, logger *nblog.Logger, flowLogger nftypes.FlowLogger) *udpForwarder {
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
f := &udpForwarder{
|
f := &udpForwarder{
|
||||||
logger: logger,
|
logger: logger,
|
||||||
|
flowLogger: flowLogger,
|
||||||
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
conns: make(map[stack.TransportEndpointID]*udpPacketConn),
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
@@ -72,10 +78,10 @@ func (f *udpForwarder) Stop() {
|
|||||||
for id, conn := range f.conns {
|
for id, conn := range f.conns {
|
||||||
conn.cancel()
|
conn.cancel()
|
||||||
if err := conn.conn.Close(); err != nil {
|
if err := conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := conn.outConn.Close(); err != nil {
|
if err := conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
conn.ep.Close()
|
conn.ep.Close()
|
||||||
@@ -106,10 +112,10 @@ func (f *udpForwarder) cleanup() {
|
|||||||
for _, idle := range idleConns {
|
for _, idle := range idleConns {
|
||||||
idle.conn.cancel()
|
idle.conn.cancel()
|
||||||
if err := idle.conn.conn.Close(); err != nil {
|
if err := idle.conn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP conn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP conn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
if err := idle.conn.outConn.Close(); err != nil {
|
if err := idle.conn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", idle.id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(idle.id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
idle.conn.ep.Close()
|
idle.conn.ep.Close()
|
||||||
@@ -118,7 +124,7 @@ func (f *udpForwarder) cleanup() {
|
|||||||
delete(f.conns, idle.id)
|
delete(f.conns, idle.id)
|
||||||
f.Unlock()
|
f.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", idle.id)
|
f.logger.Trace("forwarder: cleaned up idle UDP connection %v", epID(idle.id))
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -137,14 +143,24 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
_, exists := f.udpForwarder.conns[id]
|
_, exists := f.udpForwarder.conns[id]
|
||||||
f.udpForwarder.RUnlock()
|
f.udpForwarder.RUnlock()
|
||||||
if exists {
|
if exists {
|
||||||
f.logger.Trace("forwarder: existing UDP connection for %v", id)
|
f.logger.Trace("forwarder: existing UDP connection for %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
|
flowID := uuid.New()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeStart, flowID, id, nil)
|
||||||
|
var success bool
|
||||||
|
defer func() {
|
||||||
|
if !success {
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, flowID, id, nil)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
dstAddr := fmt.Sprintf("%s:%d", f.determineDialAddr(id.LocalAddress), id.LocalPort)
|
||||||
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
outConn, err := (&net.Dialer{}).DialContext(f.ctx, "udp", dstAddr)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
f.logger.Debug("forwarder: UDP dial error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP dial error for %v: %v", epID(id), err)
|
||||||
// TODO: Send ICMP error message
|
// TODO: Send ICMP error message
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -155,7 +171,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
if epErr != nil {
|
if epErr != nil {
|
||||||
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
f.logger.Debug("forwarder: failed to create UDP endpoint: %v", epErr)
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -168,6 +184,7 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
outConn: outConn,
|
outConn: outConn,
|
||||||
cancel: connCancel,
|
cancel: connCancel,
|
||||||
ep: ep,
|
ep: ep,
|
||||||
|
flowID: flowID,
|
||||||
}
|
}
|
||||||
pConn.updateLastSeen()
|
pConn.updateLastSeen()
|
||||||
|
|
||||||
@@ -177,17 +194,20 @@ func (f *Forwarder) handleUDP(r *udp.ForwarderRequest) {
|
|||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := inConn.Close(); err != nil {
|
if err := inConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := outConn.Close(); err != nil {
|
if err := outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
f.udpForwarder.conns[id] = pConn
|
f.udpForwarder.conns[id] = pConn
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
f.logger.Trace("forwarder: established UDP connection to %v", id)
|
success = true
|
||||||
|
f.logger.Trace("forwarder: established UDP connection %v", epID(id))
|
||||||
|
|
||||||
go f.proxyUDP(connCtx, pConn, id, ep)
|
go f.proxyUDP(connCtx, pConn, id, ep)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -195,10 +215,10 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
defer func() {
|
defer func() {
|
||||||
pConn.cancel()
|
pConn.cancel()
|
||||||
if err := pConn.conn.Close(); err != nil {
|
if err := pConn.conn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP inConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
if err := pConn.outConn.Close(); err != nil {
|
if err := pConn.outConn.Close(); err != nil {
|
||||||
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", id, err)
|
f.logger.Debug("forwarder: UDP outConn close error for %v: %v", epID(id), err)
|
||||||
}
|
}
|
||||||
|
|
||||||
ep.Close()
|
ep.Close()
|
||||||
@@ -206,6 +226,8 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
f.udpForwarder.Lock()
|
f.udpForwarder.Lock()
|
||||||
delete(f.udpForwarder.conns, id)
|
delete(f.udpForwarder.conns, id)
|
||||||
f.udpForwarder.Unlock()
|
f.udpForwarder.Unlock()
|
||||||
|
|
||||||
|
f.sendUDPEvent(nftypes.TypeEnd, pConn.flowID, id, ep)
|
||||||
}()
|
}()
|
||||||
|
|
||||||
errChan := make(chan error, 2)
|
errChan := make(chan error, 2)
|
||||||
@@ -220,17 +242,43 @@ func (f *Forwarder) proxyUDP(ctx context.Context, pConn *udpPacketConn, id stack
|
|||||||
|
|
||||||
select {
|
select {
|
||||||
case <-ctx.Done():
|
case <-ctx.Done():
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v due to context done", epID(id))
|
||||||
return
|
return
|
||||||
case err := <-errChan:
|
case err := <-errChan:
|
||||||
if err != nil && !isClosedError(err) {
|
if err != nil && !isClosedError(err) {
|
||||||
f.logger.Error("proxyUDP: copy error: %v", err)
|
f.logger.Error("proxyUDP: copy error: %v", err)
|
||||||
}
|
}
|
||||||
f.logger.Trace("forwarder: tearing down UDP connection %v", id)
|
f.logger.Trace("forwarder: tearing down UDP connection %v", epID(id))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// sendUDPEvent stores flow events for UDP connections
|
||||||
|
func (f *Forwarder) sendUDPEvent(typ nftypes.Type, flowID uuid.UUID, id stack.TransportEndpointID, ep tcpip.Endpoint) {
|
||||||
|
fields := nftypes.EventFields{
|
||||||
|
FlowID: flowID,
|
||||||
|
Type: typ,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: nftypes.UDP,
|
||||||
|
// TODO: handle ipv6
|
||||||
|
SourceIP: netip.AddrFrom4(id.RemoteAddress.As4()),
|
||||||
|
DestIP: netip.AddrFrom4(id.LocalAddress.As4()),
|
||||||
|
SourcePort: id.RemotePort,
|
||||||
|
DestPort: id.LocalPort,
|
||||||
|
}
|
||||||
|
|
||||||
|
if ep != nil {
|
||||||
|
if tcpStats, ok := ep.Stats().(*tcpip.TransportEndpointStats); ok {
|
||||||
|
// fields are flipped since this is the in conn
|
||||||
|
// TODO: get bytes
|
||||||
|
fields.RxPackets = tcpStats.PacketsSent.Value()
|
||||||
|
fields.TxPackets = tcpStats.PacketsReceived.Value()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.flowLogger.StoreEvent(fields)
|
||||||
|
}
|
||||||
|
|
||||||
func (c *udpPacketConn) updateLastSeen() {
|
func (c *udpPacketConn) updateLastSeen() {
|
||||||
c.lastSeen.Store(time.Now().UnixNano())
|
c.lastSeen.Store(time.Now().UnixNano())
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -13,8 +14,13 @@ import (
|
|||||||
type localIPManager struct {
|
type localIPManager struct {
|
||||||
mu sync.RWMutex
|
mu sync.RWMutex
|
||||||
|
|
||||||
// Use bitmap for IPv4 (32 bits * 2^16 = 256KB memory)
|
// fixed-size high array for upper byte of a IPv4 address
|
||||||
ipv4Bitmap [1 << 16]uint32
|
ipv4Bitmap [256]*ipv4LowBitmap
|
||||||
|
}
|
||||||
|
|
||||||
|
// ipv4LowBitmap is a map for the low 16 bits of a IPv4 address
|
||||||
|
type ipv4LowBitmap struct {
|
||||||
|
bitmap [8192]uint32
|
||||||
}
|
}
|
||||||
|
|
||||||
func newLocalIPManager() *localIPManager {
|
func newLocalIPManager() *localIPManager {
|
||||||
@@ -26,39 +32,59 @@ func (m *localIPManager) setBitmapBit(ip net.IP) {
|
|||||||
if ipv4 == nil {
|
if ipv4 == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
high := uint16(ipv4[0])
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
m.ipv4Bitmap[high] |= 1 << (low % 32)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *localIPManager) checkBitmapBit(ip net.IP) bool {
|
index := low / 32
|
||||||
ipv4 := ip.To4()
|
bit := low % 32
|
||||||
if ipv4 == nil {
|
|
||||||
return false
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
m.ipv4Bitmap[high] = &ipv4LowBitmap{}
|
||||||
}
|
}
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
m.ipv4Bitmap[high].bitmap[index] |= 1 << bit
|
||||||
return (m.ipv4Bitmap[high] & (1 << (low % 32))) != 0
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processIP(ip net.IP, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
func (m *localIPManager) setBitInBitmap(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
if ipv4 := ip.To4(); ipv4 != nil {
|
||||||
high := (uint16(ipv4[0]) << 8) | uint16(ipv4[1])
|
high := uint16(ipv4[0])
|
||||||
low := (uint16(ipv4[2]) << 8) | uint16(ipv4[3])
|
low := (uint16(ipv4[1]) << 8) | (uint16(ipv4[2]) << 4) | uint16(ipv4[3])
|
||||||
if int(high) >= len(*newIPv4Bitmap) {
|
|
||||||
return fmt.Errorf("invalid IPv4 address: %s", ip)
|
if bitmap[high] == nil {
|
||||||
|
bitmap[high] = &ipv4LowBitmap{}
|
||||||
}
|
}
|
||||||
ipStr := ip.String()
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
bitmap[high].bitmap[index] |= 1 << bit
|
||||||
|
|
||||||
|
ipStr := ipv4.String()
|
||||||
if _, exists := ipv4Set[ipStr]; !exists {
|
if _, exists := ipv4Set[ipStr]; !exists {
|
||||||
ipv4Set[ipStr] = struct{}{}
|
ipv4Set[ipStr] = struct{}{}
|
||||||
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
*ipv4Addresses = append(*ipv4Addresses, ipStr)
|
||||||
newIPv4Bitmap[high] |= 1 << (low % 32)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) checkBitmapBit(ip []byte) bool {
|
||||||
|
high := uint16(ip[0])
|
||||||
|
low := (uint16(ip[1]) << 8) | (uint16(ip[2]) << 4) | uint16(ip[3])
|
||||||
|
|
||||||
|
if m.ipv4Bitmap[high] == nil {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
index := low / 32
|
||||||
|
bit := low % 32
|
||||||
|
return (m.ipv4Bitmap[high].bitmap[index] & (1 << bit)) != 0
|
||||||
|
}
|
||||||
|
|
||||||
|
func (m *localIPManager) processIP(ip net.IP, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) error {
|
||||||
|
m.setBitInBitmap(ip, bitmap, ipv4Set, ipv4Addresses)
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1 << 16]uint32, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
func (m *localIPManager) processInterface(iface net.Interface, bitmap *[256]*ipv4LowBitmap, ipv4Set map[string]struct{}, ipv4Addresses *[]string) {
|
||||||
addrs, err := iface.Addrs()
|
addrs, err := iface.Addrs()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
log.Debugf("get addresses for interface %s failed: %v", iface.Name, err)
|
||||||
@@ -76,7 +102,7 @@ func (m *localIPManager) processInterface(iface net.Interface, newIPv4Bitmap *[1
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.processIP(ip, newIPv4Bitmap, ipv4Set, ipv4Addresses); err != nil {
|
if err := m.processIP(ip, bitmap, ipv4Set, ipv4Addresses); err != nil {
|
||||||
log.Debugf("process IP failed: %v", err)
|
log.Debugf("process IP failed: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -89,14 +115,14 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
|
||||||
var newIPv4Bitmap [1 << 16]uint32
|
var newIPv4Bitmap [256]*ipv4LowBitmap
|
||||||
ipv4Set := make(map[string]struct{})
|
ipv4Set := make(map[string]struct{})
|
||||||
var ipv4Addresses []string
|
var ipv4Addresses []string
|
||||||
|
|
||||||
// 127.0.0.0/8
|
// 127.0.0.0/8
|
||||||
high := uint16(127) << 8
|
newIPv4Bitmap[127] = &ipv4LowBitmap{}
|
||||||
for i := uint16(0); i < 256; i++ {
|
for i := 0; i < 8192; i++ {
|
||||||
newIPv4Bitmap[high|i] = 0xffffffff
|
newIPv4Bitmap[127].bitmap[i] = 0xFFFFFFFF
|
||||||
}
|
}
|
||||||
|
|
||||||
if iface != nil {
|
if iface != nil {
|
||||||
@@ -122,13 +148,13 @@ func (m *localIPManager) UpdateLocalIPs(iface common.IFaceMapper) (err error) {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *localIPManager) IsLocalIP(ip net.IP) bool {
|
func (m *localIPManager) IsLocalIP(ip netip.Addr) bool {
|
||||||
|
if !ip.Is4() {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
m.mu.RLock()
|
m.mu.RLock()
|
||||||
defer m.mu.RUnlock()
|
defer m.mu.RUnlock()
|
||||||
|
|
||||||
if ipv4 := ip.To4(); ipv4 != nil {
|
return m.checkBitmapBit(ip.AsSlice())
|
||||||
return m.checkBitmapBit(ipv4)
|
|
||||||
}
|
|
||||||
|
|
||||||
return false
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"testing"
|
"testing"
|
||||||
|
|
||||||
"github.com/stretchr/testify/require"
|
"github.com/stretchr/testify/require"
|
||||||
@@ -13,7 +14,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
tests := []struct {
|
tests := []struct {
|
||||||
name string
|
name string
|
||||||
setupAddr wgaddr.Address
|
setupAddr wgaddr.Address
|
||||||
testIP net.IP
|
testIP netip.Addr
|
||||||
expected bool
|
expected bool
|
||||||
}{
|
}{
|
||||||
{
|
{
|
||||||
@@ -25,7 +26,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.2"),
|
testIP: netip.MustParseAddr("127.0.0.2"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -37,7 +38,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.0.0.1"),
|
testIP: netip.MustParseAddr("127.0.0.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -49,7 +50,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("127.255.255.255"),
|
testIP: netip.MustParseAddr("127.255.255.255"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -61,7 +62,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.1"),
|
testIP: netip.MustParseAddr("192.168.1.1"),
|
||||||
expected: true,
|
expected: true,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -73,7 +74,19 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(24, 32),
|
Mask: net.CIDRMask(24, 32),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("192.168.1.2"),
|
testIP: netip.MustParseAddr("192.168.1.2"),
|
||||||
|
expected: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Local IP doesn't match - addresses 32 apart",
|
||||||
|
setupAddr: wgaddr.Address{
|
||||||
|
IP: net.ParseIP("192.168.1.1"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("192.168.1.0"),
|
||||||
|
Mask: net.CIDRMask(24, 32),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
testIP: netip.MustParseAddr("192.168.1.33"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -85,7 +98,7 @@ func TestLocalIPManager(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(64, 128),
|
Mask: net.CIDRMask(64, 128),
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
testIP: net.ParseIP("fe80::1"),
|
testIP: netip.MustParseAddr("fe80::1"),
|
||||||
expected: false,
|
expected: false,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -174,7 +187,7 @@ func TestLocalIPManager_AllInterfaces(t *testing.T) {
|
|||||||
t.Logf("Testing %d IPs", len(tests))
|
t.Logf("Testing %d IPs", len(tests))
|
||||||
for _, tt := range tests {
|
for _, tt := range tests {
|
||||||
t.Run(tt.ip, func(t *testing.T) {
|
t.Run(tt.ip, func(t *testing.T) {
|
||||||
result := manager.IsLocalIP(net.ParseIP(tt.ip))
|
result := manager.IsLocalIP(netip.MustParseAddr(tt.ip))
|
||||||
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
require.Equal(t, tt.expected, result, "IP: %s", tt.ip)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -191,10 +204,8 @@ func BenchmarkIPChecks(b *testing.B) {
|
|||||||
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
interfaces[i] = net.IPv4(10, 0, byte(i>>8), byte(i))
|
||||||
}
|
}
|
||||||
|
|
||||||
// Setup bitmap version
|
// Setup bitmap
|
||||||
bitmapManager := &localIPManager{
|
bitmapManager := newLocalIPManager()
|
||||||
ipv4Bitmap: [1 << 16]uint32{},
|
|
||||||
}
|
|
||||||
for _, ip := range interfaces[:8] { // Add half of IPs
|
for _, ip := range interfaces[:8] { // Add half of IPs
|
||||||
bitmapManager.setBitmapBit(ip)
|
bitmapManager.setBitmapBit(ip)
|
||||||
}
|
}
|
||||||
@@ -247,7 +258,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
|||||||
|
|
||||||
// Create two managers - one checks WG IP first, other checks it last
|
// Create two managers - one checks WG IP first, other checks it last
|
||||||
b.Run("WG_First", func(b *testing.B) {
|
b.Run("WG_First", func(b *testing.B) {
|
||||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
bm := newLocalIPManager()
|
||||||
bm.setBitmapBit(wgIP)
|
bm.setBitmapBit(wgIP)
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
@@ -256,7 +267,7 @@ func BenchmarkWGPosition(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
b.Run("WG_Last", func(b *testing.B) {
|
b.Run("WG_Last", func(b *testing.B) {
|
||||||
bm := &localIPManager{ipv4Bitmap: [1 << 16]uint32{}}
|
bm := newLocalIPManager()
|
||||||
// Fill with other IPs first
|
// Fill with other IPs first
|
||||||
for i := 0; i < 15; i++ {
|
for i := 0; i < 15; i++ {
|
||||||
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
bm.setBitmapBit(net.IPv4(10, 0, byte(i>>8), byte(i)))
|
||||||
|
|||||||
@@ -1,4 +1,4 @@
|
|||||||
// Package logger provides a high-performance, non-blocking logger for userspace networking
|
// Package log provides a high-performance, non-blocking logger for userspace networking
|
||||||
package log
|
package log
|
||||||
|
|
||||||
import (
|
import (
|
||||||
@@ -13,13 +13,12 @@ import (
|
|||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
maxBatchSize = 1024 * 16 // 16KB max batch size
|
maxBatchSize = 1024 * 16
|
||||||
maxMessageSize = 1024 * 2 // 2KB per message
|
maxMessageSize = 1024 * 2
|
||||||
bufferSize = 1024 * 256 // 256KB ring buffer
|
|
||||||
defaultFlushInterval = 2 * time.Second
|
defaultFlushInterval = 2 * time.Second
|
||||||
|
logChannelSize = 1000
|
||||||
)
|
)
|
||||||
|
|
||||||
// Level represents log severity
|
|
||||||
type Level uint32
|
type Level uint32
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -42,32 +41,37 @@ var levelStrings = map[Level]string{
|
|||||||
LevelTrace: "TRAC",
|
LevelTrace: "TRAC",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type logMessage struct {
|
||||||
|
level Level
|
||||||
|
format string
|
||||||
|
args []any
|
||||||
|
}
|
||||||
|
|
||||||
// Logger is a high-performance, non-blocking logger
|
// Logger is a high-performance, non-blocking logger
|
||||||
type Logger struct {
|
type Logger struct {
|
||||||
output io.Writer
|
output io.Writer
|
||||||
level atomic.Uint32
|
level atomic.Uint32
|
||||||
buffer *ringBuffer
|
msgChannel chan logMessage
|
||||||
shutdown chan struct{}
|
shutdown chan struct{}
|
||||||
closeOnce sync.Once
|
closeOnce sync.Once
|
||||||
wg sync.WaitGroup
|
wg sync.WaitGroup
|
||||||
|
|
||||||
// Reusable buffer pool for formatting messages
|
|
||||||
bufPool sync.Pool
|
bufPool sync.Pool
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// NewFromLogrus creates a new Logger that writes to the same output as the given logrus logger
|
||||||
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
||||||
l := &Logger{
|
l := &Logger{
|
||||||
output: logrusLogger.Out,
|
output: logrusLogger.Out,
|
||||||
buffer: newRingBuffer(bufferSize),
|
msgChannel: make(chan logMessage, logChannelSize),
|
||||||
shutdown: make(chan struct{}),
|
shutdown: make(chan struct{}),
|
||||||
bufPool: sync.Pool{
|
bufPool: sync.Pool{
|
||||||
New: func() interface{} {
|
New: func() any {
|
||||||
// Pre-allocate buffer for message formatting
|
|
||||||
b := make([]byte, 0, maxMessageSize)
|
b := make([]byte, 0, maxMessageSize)
|
||||||
return &b
|
return &b
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
logrusLevel := logrusLogger.GetLevel()
|
logrusLevel := logrusLogger.GetLevel()
|
||||||
l.level.Store(uint32(logrusLevel))
|
l.level.Store(uint32(logrusLevel))
|
||||||
level := levelStrings[Level(logrusLevel)]
|
level := levelStrings[Level(logrusLevel)]
|
||||||
@@ -79,97 +83,149 @@ func NewFromLogrus(logrusLogger *log.Logger) *Logger {
|
|||||||
return l
|
return l
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// SetLevel sets the logging level
|
||||||
func (l *Logger) SetLevel(level Level) {
|
func (l *Logger) SetLevel(level Level) {
|
||||||
l.level.Store(uint32(level))
|
l.level.Store(uint32(level))
|
||||||
|
|
||||||
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
log.Debugf("Set uspfilter logger loglevel to %v", levelStrings[level])
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...interface{}) {
|
func (l *Logger) log(level Level, format string, args ...any) {
|
||||||
*buf = (*buf)[:0]
|
select {
|
||||||
|
case l.msgChannel <- logMessage{level: level, format: format, args: args}:
|
||||||
// Timestamp
|
default:
|
||||||
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Level
|
|
||||||
*buf = append(*buf, levelStrings[level]...)
|
|
||||||
*buf = append(*buf, ' ')
|
|
||||||
|
|
||||||
// Message
|
|
||||||
if len(args) > 0 {
|
|
||||||
*buf = append(*buf, fmt.Sprintf(format, args...)...)
|
|
||||||
} else {
|
|
||||||
*buf = append(*buf, format...)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
*buf = append(*buf, '\n')
|
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) log(level Level, format string, args ...interface{}) {
|
// Error logs a message at error level
|
||||||
bufp := l.bufPool.Get().(*[]byte)
|
func (l *Logger) Error(format string, args ...any) {
|
||||||
l.formatMessage(bufp, level, format, args...)
|
|
||||||
|
|
||||||
if len(*bufp) > maxMessageSize {
|
|
||||||
*bufp = (*bufp)[:maxMessageSize]
|
|
||||||
}
|
|
||||||
_, _ = l.buffer.Write(*bufp)
|
|
||||||
|
|
||||||
l.bufPool.Put(bufp)
|
|
||||||
}
|
|
||||||
|
|
||||||
func (l *Logger) Error(format string, args ...interface{}) {
|
|
||||||
if l.level.Load() >= uint32(LevelError) {
|
if l.level.Load() >= uint32(LevelError) {
|
||||||
l.log(LevelError, format, args...)
|
l.log(LevelError, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Warn(format string, args ...interface{}) {
|
// Warn logs a message at warning level
|
||||||
|
func (l *Logger) Warn(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelWarn) {
|
if l.level.Load() >= uint32(LevelWarn) {
|
||||||
l.log(LevelWarn, format, args...)
|
l.log(LevelWarn, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Info(format string, args ...interface{}) {
|
// Info logs a message at info level
|
||||||
|
func (l *Logger) Info(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelInfo) {
|
if l.level.Load() >= uint32(LevelInfo) {
|
||||||
l.log(LevelInfo, format, args...)
|
l.log(LevelInfo, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Debug(format string, args ...interface{}) {
|
// Debug logs a message at debug level
|
||||||
|
func (l *Logger) Debug(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelDebug) {
|
if l.level.Load() >= uint32(LevelDebug) {
|
||||||
l.log(LevelDebug, format, args...)
|
l.log(LevelDebug, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (l *Logger) Trace(format string, args ...interface{}) {
|
// Trace logs a message at trace level
|
||||||
|
func (l *Logger) Trace(format string, args ...any) {
|
||||||
if l.level.Load() >= uint32(LevelTrace) {
|
if l.level.Load() >= uint32(LevelTrace) {
|
||||||
l.log(LevelTrace, format, args...)
|
l.log(LevelTrace, format, args...)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// worker periodically flushes the buffer
|
func (l *Logger) formatMessage(buf *[]byte, level Level, format string, args ...any) {
|
||||||
|
*buf = (*buf)[:0]
|
||||||
|
*buf = time.Now().AppendFormat(*buf, "2006-01-02T15:04:05-07:00")
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
*buf = append(*buf, levelStrings[level]...)
|
||||||
|
*buf = append(*buf, ' ')
|
||||||
|
|
||||||
|
var msg string
|
||||||
|
if len(args) > 0 {
|
||||||
|
msg = fmt.Sprintf(format, args...)
|
||||||
|
} else {
|
||||||
|
msg = format
|
||||||
|
}
|
||||||
|
*buf = append(*buf, msg...)
|
||||||
|
*buf = append(*buf, '\n')
|
||||||
|
|
||||||
|
if len(*buf) > maxMessageSize {
|
||||||
|
*buf = (*buf)[:maxMessageSize]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processMessage handles a single log message and adds it to the buffer
|
||||||
|
func (l *Logger) processMessage(msg logMessage, buffer *[]byte) {
|
||||||
|
bufp := l.bufPool.Get().(*[]byte)
|
||||||
|
defer l.bufPool.Put(bufp)
|
||||||
|
|
||||||
|
l.formatMessage(bufp, msg.level, msg.format, msg.args...)
|
||||||
|
|
||||||
|
if len(*buffer)+len(*bufp) > maxBatchSize {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
*buffer = append(*buffer, *bufp...)
|
||||||
|
}
|
||||||
|
|
||||||
|
// flushBuffer writes the accumulated buffer to output
|
||||||
|
func (l *Logger) flushBuffer(buffer *[]byte) {
|
||||||
|
if len(*buffer) > 0 {
|
||||||
|
_, _ = l.output.Write(*buffer)
|
||||||
|
*buffer = (*buffer)[:0]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// processBatch processes as many messages as possible without blocking
|
||||||
|
func (l *Logger) processBatch(buffer *[]byte) {
|
||||||
|
for len(*buffer) < maxBatchSize {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
default:
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleShutdown manages the graceful shutdown sequence with timeout
|
||||||
|
func (l *Logger) handleShutdown(buffer *[]byte) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 500*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
|
||||||
|
for {
|
||||||
|
select {
|
||||||
|
case msg := <-l.msgChannel:
|
||||||
|
l.processMessage(msg, buffer)
|
||||||
|
case <-ctx.Done():
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
if len(l.msgChannel) == 0 {
|
||||||
|
l.flushBuffer(buffer)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// worker is the main goroutine that processes log messages
|
||||||
func (l *Logger) worker() {
|
func (l *Logger) worker() {
|
||||||
defer l.wg.Done()
|
defer l.wg.Done()
|
||||||
|
|
||||||
ticker := time.NewTicker(defaultFlushInterval)
|
ticker := time.NewTicker(defaultFlushInterval)
|
||||||
defer ticker.Stop()
|
defer ticker.Stop()
|
||||||
|
|
||||||
buf := make([]byte, 0, maxBatchSize)
|
buffer := make([]byte, 0, maxBatchSize)
|
||||||
|
|
||||||
for {
|
for {
|
||||||
select {
|
select {
|
||||||
case <-l.shutdown:
|
case <-l.shutdown:
|
||||||
|
l.handleShutdown(&buffer)
|
||||||
return
|
return
|
||||||
case <-ticker.C:
|
case <-ticker.C:
|
||||||
// Read accumulated messages
|
l.flushBuffer(&buffer)
|
||||||
n, _ := l.buffer.Read(buf[:cap(buf)])
|
case msg := <-l.msgChannel:
|
||||||
if n == 0 {
|
l.processMessage(msg, &buffer)
|
||||||
continue
|
l.processBatch(&buffer)
|
||||||
}
|
|
||||||
|
|
||||||
// Write batch
|
|
||||||
_, _ = l.output.Write(buf[:n])
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
121
client/firewall/uspfilter/log/log_test.go
Normal file
121
client/firewall/uspfilter/log/log_test.go
Normal file
@@ -0,0 +1,121 @@
|
|||||||
|
package log_test
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
|
||||||
|
"github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
|
)
|
||||||
|
|
||||||
|
type discard struct{}
|
||||||
|
|
||||||
|
func (d *discard) Write(p []byte) (n int, err error) {
|
||||||
|
return len(p), nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func BenchmarkLogger(b *testing.B) {
|
||||||
|
simpleMessage := "Connection established"
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4 // TCPStateEstablished
|
||||||
|
|
||||||
|
complexMessage := "Packet inspection result: protocol=%s, direction=%s, flags=0x%x, sequence=%d, acknowledged=%d, payload_size=%d, fragmented=%v, connection_id=%s"
|
||||||
|
protocol := "TCP"
|
||||||
|
direction := "outbound"
|
||||||
|
flags := uint16(0x18) // ACK + PSH
|
||||||
|
sequence := uint32(123456789)
|
||||||
|
acknowledged := uint32(987654321)
|
||||||
|
payloadSize := 1460
|
||||||
|
fragmented := false
|
||||||
|
connID := "f7a12b3e-c456-7890-d123-456789abcdef"
|
||||||
|
|
||||||
|
b.Run("SimpleMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(simpleMessage)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ConntrackMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
|
||||||
|
b.Run("ComplexMessage", func(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
logger.Trace(complexMessage, protocol, direction, flags, sequence, acknowledged, payloadSize, fragmented, connID)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerParallel tests the logger under concurrent load
|
||||||
|
func BenchmarkLoggerParallel(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
b.RunParallel(func(pb *testing.PB) {
|
||||||
|
for pb.Next() {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
// BenchmarkLoggerBurst tests how the logger handles bursts of messages
|
||||||
|
func BenchmarkLoggerBurst(b *testing.B) {
|
||||||
|
logger := createTestLogger()
|
||||||
|
defer cleanupLogger(logger)
|
||||||
|
|
||||||
|
conntrackMessage := "TCP connection %s:%d -> %s:%d state changed to %d"
|
||||||
|
srcIP := "192.168.1.1"
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstIP := "10.0.0.1"
|
||||||
|
dstPort := uint16(443)
|
||||||
|
state := 4
|
||||||
|
|
||||||
|
b.ResetTimer()
|
||||||
|
for i := 0; i < b.N; i++ {
|
||||||
|
for j := 0; j < 100; j++ {
|
||||||
|
logger.Trace(conntrackMessage, srcIP, srcPort, dstIP, dstPort, state)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createTestLogger() *log.Logger {
|
||||||
|
logrusLogger := logrus.New()
|
||||||
|
logrusLogger.SetOutput(&discard{})
|
||||||
|
logrusLogger.SetLevel(logrus.TraceLevel)
|
||||||
|
return log.NewFromLogrus(logrusLogger)
|
||||||
|
}
|
||||||
|
|
||||||
|
func cleanupLogger(logger *log.Logger) {
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), 100*time.Millisecond)
|
||||||
|
defer cancel()
|
||||||
|
_ = logger.Stop(ctx)
|
||||||
|
}
|
||||||
@@ -1,85 +0,0 @@
|
|||||||
package log
|
|
||||||
|
|
||||||
import "sync"
|
|
||||||
|
|
||||||
// ringBuffer is a simple ring buffer implementation
|
|
||||||
type ringBuffer struct {
|
|
||||||
buf []byte
|
|
||||||
size int
|
|
||||||
r, w int64 // Read and write positions
|
|
||||||
mu sync.Mutex
|
|
||||||
}
|
|
||||||
|
|
||||||
func newRingBuffer(size int) *ringBuffer {
|
|
||||||
return &ringBuffer{
|
|
||||||
buf: make([]byte, size),
|
|
||||||
size: size,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Write(p []byte) (n int, err error) {
|
|
||||||
if len(p) == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if len(p) > r.size {
|
|
||||||
p = p[:r.size]
|
|
||||||
}
|
|
||||||
|
|
||||||
n = len(p)
|
|
||||||
|
|
||||||
// Write data, handling wrap-around
|
|
||||||
pos := int(r.w % int64(r.size))
|
|
||||||
writeLen := min(len(p), r.size-pos)
|
|
||||||
copy(r.buf[pos:], p[:writeLen])
|
|
||||||
|
|
||||||
// If we have more data and need to wrap around
|
|
||||||
if writeLen < len(p) {
|
|
||||||
copy(r.buf, p[writeLen:])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update write position
|
|
||||||
r.w += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
func (r *ringBuffer) Read(p []byte) (n int, err error) {
|
|
||||||
r.mu.Lock()
|
|
||||||
defer r.mu.Unlock()
|
|
||||||
|
|
||||||
if r.w == r.r {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Calculate available data accounting for wraparound
|
|
||||||
available := int(r.w - r.r)
|
|
||||||
if available < 0 {
|
|
||||||
available += r.size
|
|
||||||
}
|
|
||||||
available = min(available, r.size)
|
|
||||||
|
|
||||||
// Limit read to buffer size
|
|
||||||
toRead := min(available, len(p))
|
|
||||||
if toRead == 0 {
|
|
||||||
return 0, nil
|
|
||||||
}
|
|
||||||
|
|
||||||
// Read data, handling wrap-around
|
|
||||||
pos := int(r.r % int64(r.size))
|
|
||||||
readLen := min(toRead, r.size-pos)
|
|
||||||
n = copy(p, r.buf[pos:pos+readLen])
|
|
||||||
|
|
||||||
// If we need more data and need to wrap around
|
|
||||||
if readLen < toRead {
|
|
||||||
n += copy(p[readLen:toRead], r.buf[:toRead-readLen])
|
|
||||||
}
|
|
||||||
|
|
||||||
// Update read position
|
|
||||||
r.r += int64(n)
|
|
||||||
|
|
||||||
return n, nil
|
|
||||||
}
|
|
||||||
@@ -1,7 +1,6 @@
|
|||||||
package uspfilter
|
package uspfilter
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
|
||||||
"net/netip"
|
"net/netip"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -12,14 +11,14 @@ import (
|
|||||||
// PeerRule to handle management of rules
|
// PeerRule to handle management of rules
|
||||||
type PeerRule struct {
|
type PeerRule struct {
|
||||||
id string
|
id string
|
||||||
ip net.IP
|
mgmtId []byte
|
||||||
|
ip netip.Addr
|
||||||
ipLayer gopacket.LayerType
|
ipLayer gopacket.LayerType
|
||||||
matchByIP bool
|
matchByIP bool
|
||||||
protoLayer gopacket.LayerType
|
protoLayer gopacket.LayerType
|
||||||
sPort *firewall.Port
|
sPort *firewall.Port
|
||||||
dPort *firewall.Port
|
dPort *firewall.Port
|
||||||
drop bool
|
drop bool
|
||||||
comment string
|
|
||||||
|
|
||||||
udpHook func([]byte) bool
|
udpHook func([]byte) bool
|
||||||
}
|
}
|
||||||
@@ -31,6 +30,7 @@ func (r *PeerRule) ID() string {
|
|||||||
|
|
||||||
type RouteRule struct {
|
type RouteRule struct {
|
||||||
id string
|
id string
|
||||||
|
mgmtId []byte
|
||||||
sources []netip.Prefix
|
sources []netip.Prefix
|
||||||
destination netip.Prefix
|
destination netip.Prefix
|
||||||
proto firewall.Protocol
|
proto firewall.Protocol
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package uspfilter
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -53,8 +53,8 @@ type TraceResult struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketTrace struct {
|
type PacketTrace struct {
|
||||||
SourceIP net.IP
|
SourceIP netip.Addr
|
||||||
DestinationIP net.IP
|
DestinationIP netip.Addr
|
||||||
Protocol string
|
Protocol string
|
||||||
SourcePort uint16
|
SourcePort uint16
|
||||||
DestinationPort uint16
|
DestinationPort uint16
|
||||||
@@ -72,8 +72,8 @@ type TCPState struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type PacketBuilder struct {
|
type PacketBuilder struct {
|
||||||
SrcIP net.IP
|
SrcIP netip.Addr
|
||||||
DstIP net.IP
|
DstIP netip.Addr
|
||||||
Protocol fw.Protocol
|
Protocol fw.Protocol
|
||||||
SrcPort uint16
|
SrcPort uint16
|
||||||
DstPort uint16
|
DstPort uint16
|
||||||
@@ -126,8 +126,8 @@ func (p *PacketBuilder) buildIPLayer() *layers.IPv4 {
|
|||||||
Version: 4,
|
Version: 4,
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
Protocol: layers.IPProtocol(getIPProtocolNumber(p.Protocol)),
|
||||||
SrcIP: p.SrcIP,
|
SrcIP: p.SrcIP.AsSlice(),
|
||||||
DstIP: p.DstIP,
|
DstIP: p.DstIP.AsSlice(),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -260,28 +260,30 @@ func (m *Manager) TracePacket(packetData []byte, direction fw.RuleDirection) *Pa
|
|||||||
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
return m.traceInbound(packetData, trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP net.IP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) traceInbound(packetData []byte, trace *PacketTrace, d *decoder, srcIP netip.Addr, dstIP netip.Addr) *PacketTrace {
|
||||||
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
if m.stateful && m.handleConntrackState(trace, d, srcIP, dstIP) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
if m.handleLocalDelivery(trace, packetData, d, srcIP, dstIP) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if !m.handleRouting(trace) {
|
if !m.handleRouting(trace) {
|
||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
return m.handleNativeRouter(trace)
|
return m.handleNativeRouter(trace)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
return m.handleRouteACLs(trace, d, srcIP, dstIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleConntrackState(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
allowed := m.isValidTrackedConnection(d, srcIP, dstIP)
|
allowed := m.isValidTrackedConnection(d, srcIP, dstIP, 0)
|
||||||
msg := "No existing connection found"
|
msg := "No existing connection found"
|
||||||
if allowed {
|
if allowed {
|
||||||
msg = m.buildConntrackStateMessage(d)
|
msg = m.buildConntrackStateMessage(d)
|
||||||
@@ -309,32 +311,46 @@ func (m *Manager) buildConntrackStateMessage(d *decoder) string {
|
|||||||
return msg
|
return msg
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) handleLocalDelivery(trace *PacketTrace, packetData []byte, d *decoder, srcIP, dstIP netip.Addr) bool {
|
||||||
if !m.localForwarding {
|
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
||||||
trace.AddResult(StageRouting, "Local forwarding disabled", false)
|
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - local forwarding disabled", false)
|
ruleId, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
|
|
||||||
|
strRuleId := "<no id>"
|
||||||
|
if ruleId != nil {
|
||||||
|
strRuleId = string(ruleId)
|
||||||
|
}
|
||||||
|
msg := fmt.Sprintf("Allowed by peer ACL rules (%s)", strRuleId)
|
||||||
|
if blocked {
|
||||||
|
msg = fmt.Sprintf("Blocked by peer ACL rules (%s)", strRuleId)
|
||||||
|
trace.AddResult(StagePeerACL, msg, false)
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - ACL denied", false)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageRouting, "Packet destined for local delivery", true)
|
trace.AddResult(StagePeerACL, msg, true)
|
||||||
blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
|
||||||
|
|
||||||
msg := "Allowed by peer ACL rules"
|
|
||||||
if blocked {
|
|
||||||
msg = "Blocked by peer ACL rules"
|
|
||||||
}
|
|
||||||
trace.AddResult(StagePeerACL, msg, !blocked)
|
|
||||||
|
|
||||||
|
// Handle netstack mode
|
||||||
if m.netstack {
|
if m.netstack {
|
||||||
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", !blocked)
|
switch {
|
||||||
|
case !m.localForwarding:
|
||||||
|
trace.AddResult(StageCompleted, "Packet sent to virtual stack", true)
|
||||||
|
case m.forwarder.Load() != nil:
|
||||||
|
m.addForwardingResult(trace, "proxy-local", "127.0.0.1", true)
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
|
default:
|
||||||
|
trace.AddResult(StageCompleted, "Packet dropped - forwarder not initialized", false)
|
||||||
|
}
|
||||||
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
trace.AddResult(StageCompleted, msgProcessingCompleted, !blocked)
|
// In normal mode, packets are allowed through for local delivery
|
||||||
|
trace.AddResult(StageCompleted, msgProcessingCompleted, true)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
func (m *Manager) handleRouting(trace *PacketTrace) bool {
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
trace.AddResult(StageRouting, "Routing disabled", false)
|
trace.AddResult(StageRouting, "Routing disabled", false)
|
||||||
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
trace.AddResult(StageCompleted, "Packet dropped - routing disabled", false)
|
||||||
return false
|
return false
|
||||||
@@ -350,18 +366,23 @@ func (m *Manager) handleNativeRouter(trace *PacketTrace) *PacketTrace {
|
|||||||
return trace
|
return trace
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP net.IP) *PacketTrace {
|
func (m *Manager) handleRouteACLs(trace *PacketTrace, d *decoder, srcIP, dstIP netip.Addr) *PacketTrace {
|
||||||
proto := getProtocolFromPacket(d)
|
proto, _ := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
id, allowed := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort)
|
||||||
|
|
||||||
msg := "Allowed by route ACLs"
|
strId := string(id)
|
||||||
|
if id == nil {
|
||||||
|
strId = "<no id>"
|
||||||
|
}
|
||||||
|
|
||||||
|
msg := fmt.Sprintf("Allowed by route ACLs (%s)", strId)
|
||||||
if !allowed {
|
if !allowed {
|
||||||
msg = "Blocked by route ACLs"
|
msg = fmt.Sprintf("Blocked by route ACLs (%s)", strId)
|
||||||
}
|
}
|
||||||
trace.AddResult(StageRouteACL, msg, allowed)
|
trace.AddResult(StageRouteACL, msg, allowed)
|
||||||
|
|
||||||
if allowed && m.forwarder != nil {
|
if allowed && m.forwarder.Load() != nil {
|
||||||
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
m.addForwardingResult(trace, "proxy-remote", fmt.Sprintf("%s:%d", dstIP, dstPort), true)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -380,7 +401,7 @@ func (m *Manager) addForwardingResult(trace *PacketTrace, action, remoteAddr str
|
|||||||
|
|
||||||
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
func (m *Manager) traceOutbound(packetData []byte, trace *PacketTrace) *PacketTrace {
|
||||||
// will create or update the connection state
|
// will create or update the connection state
|
||||||
dropped := m.processOutgoingHooks(packetData)
|
dropped := m.processOutgoingHooks(packetData, 0)
|
||||||
if dropped {
|
if dropped {
|
||||||
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
trace.AddResult(StageCompleted, "Packet dropped by outgoing hook", false)
|
||||||
} else {
|
} else {
|
||||||
|
|||||||
440
client/firewall/uspfilter/tracer_test.go
Normal file
440
client/firewall/uspfilter/tracer_test.go
Normal file
@@ -0,0 +1,440 @@
|
|||||||
|
package uspfilter
|
||||||
|
|
||||||
|
import (
|
||||||
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
fw "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/conntrack"
|
||||||
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
)
|
||||||
|
|
||||||
|
func verifyTraceStages(t *testing.T, trace *PacketTrace, expectedStages []PacketStage) {
|
||||||
|
t.Logf("Trace results: %v", trace.Results)
|
||||||
|
actualStages := make([]PacketStage, 0, len(trace.Results))
|
||||||
|
for _, result := range trace.Results {
|
||||||
|
actualStages = append(actualStages, result.Stage)
|
||||||
|
t.Logf("Stage: %s, Message: %s, Allowed: %v", result.Stage, result.Message, result.Allowed)
|
||||||
|
}
|
||||||
|
|
||||||
|
require.ElementsMatch(t, expectedStages, actualStages, "Trace stages don't match expected stages")
|
||||||
|
}
|
||||||
|
|
||||||
|
func verifyFinalDisposition(t *testing.T, trace *PacketTrace, expectedAllowed bool) {
|
||||||
|
require.NotEmpty(t, trace.Results, "Trace should have results")
|
||||||
|
lastResult := trace.Results[len(trace.Results)-1]
|
||||||
|
require.Equal(t, StageCompleted, lastResult.Stage, "Last stage should be 'Completed'")
|
||||||
|
require.Equal(t, expectedAllowed, lastResult.Allowed, "Final disposition incorrect")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestTracePacket(t *testing.T) {
|
||||||
|
setupTracerTest := func(statefulMode bool) *Manager {
|
||||||
|
ifaceMock := &IFaceMock{
|
||||||
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
|
AddressFunc: func() wgaddr.Address {
|
||||||
|
return wgaddr.Address{
|
||||||
|
IP: net.ParseIP("100.10.0.100"),
|
||||||
|
Network: &net.IPNet{
|
||||||
|
IP: net.ParseIP("100.10.0.0"),
|
||||||
|
Mask: net.CIDRMask(16, 32),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
if !statefulMode {
|
||||||
|
m.stateful = false
|
||||||
|
}
|
||||||
|
|
||||||
|
return m
|
||||||
|
}
|
||||||
|
|
||||||
|
createPacketBuilder := func(srcIP, dstIP string, protocol fw.Protocol, srcPort, dstPort uint16, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
builder := &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: protocol,
|
||||||
|
SrcPort: srcPort,
|
||||||
|
DstPort: dstPort,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
|
||||||
|
if protocol == "tcp" {
|
||||||
|
builder.TCPState = &TCPState{SYN: true}
|
||||||
|
}
|
||||||
|
|
||||||
|
return builder
|
||||||
|
}
|
||||||
|
|
||||||
|
createICMPPacketBuilder := func(srcIP, dstIP string, icmpType, icmpCode uint8, direction fw.RuleDirection) *PacketBuilder {
|
||||||
|
return &PacketBuilder{
|
||||||
|
SrcIP: netip.MustParseAddr(srcIP),
|
||||||
|
DstIP: netip.MustParseAddr(dstIP),
|
||||||
|
Protocol: "icmp",
|
||||||
|
ICMPType: icmpType,
|
||||||
|
ICMPCode: icmpCode,
|
||||||
|
Direction: direction,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
setup func(*Manager)
|
||||||
|
packetBuilder func() *PacketBuilder
|
||||||
|
expectedStages []PacketStage
|
||||||
|
expectedAllow bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = true
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "LocalTraffic_WithoutForwarder",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.netstack = true
|
||||||
|
m.localForwarding = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLAllowed",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
m.forwarder.Store(&forwarder.Forwarder{})
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionAccept)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_ACLDenied",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
|
src := netip.PrefixFrom(netip.AddrFrom4([4]byte{1, 1, 1, 1}), 32)
|
||||||
|
dst := netip.PrefixFrom(netip.AddrFrom4([4]byte{172, 17, 0, 2}), 32)
|
||||||
|
_, err := m.AddRouteFiltering(nil, []netip.Prefix{src}, dst, fw.ProtocolTCP, nil, &fw.Port{Values: []uint16{80}}, fw.ActionDrop)
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_NativeRouter",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(true)
|
||||||
|
m.nativeRouter.Store(true)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageRouteACL,
|
||||||
|
StageForwarding,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "RoutedTraffic_RoutingDisabled",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "172.17.0.2", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ConnectionTracking_Hit",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
srcIP := netip.MustParseAddr("100.10.0.100")
|
||||||
|
dstIP := netip.MustParseAddr("1.1.1.1")
|
||||||
|
srcPort := uint16(12345)
|
||||||
|
dstPort := uint16(80)
|
||||||
|
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, srcPort, dstPort, conntrack.TCPSyn, 0)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
pb := createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 80, 12345, fw.RuleDirectionIN)
|
||||||
|
pb.TCPState = &TCPState{SYN: true, ACK: true}
|
||||||
|
return pb
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "OutboundTraffic",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("100.10.0.100", "1.1.1.1", "tcp", 12345, 80, fw.RuleDirectionOUT)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPEchoRequest",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 8, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "ICMPDestinationUnreachable",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolICMP
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createICMPPacketBuilder("1.1.1.1", "100.10.0.100", 3, 0, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithoutHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolUDP
|
||||||
|
port := &fw.Port{Values: []uint16{53}}
|
||||||
|
action := fw.ActionAccept
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "UDPTraffic_WithHook",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
hookFunc := func([]byte) bool {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
m.AddUDPPacketHook(true, netip.MustParseAddr("1.1.1.1"), 53, hookFunc)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "udp", 12345, 53, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageConntrack,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "StatefulDisabled_NoTracking",
|
||||||
|
setup: func(m *Manager) {
|
||||||
|
m.stateful = false
|
||||||
|
|
||||||
|
ip := net.ParseIP("1.1.1.1")
|
||||||
|
proto := fw.ProtocolTCP
|
||||||
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
|
action := fw.ActionDrop
|
||||||
|
_, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
|
require.NoError(t, err)
|
||||||
|
},
|
||||||
|
packetBuilder: func() *PacketBuilder {
|
||||||
|
return createPacketBuilder("1.1.1.1", "100.10.0.100", "tcp", 12345, 80, fw.RuleDirectionIN)
|
||||||
|
},
|
||||||
|
expectedStages: []PacketStage{
|
||||||
|
StageReceived,
|
||||||
|
StageRouting,
|
||||||
|
StagePeerACL,
|
||||||
|
StageCompleted,
|
||||||
|
},
|
||||||
|
expectedAllow: false,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
m := setupTracerTest(true)
|
||||||
|
|
||||||
|
tc.setup(m)
|
||||||
|
|
||||||
|
require.True(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("100.10.0.100")),
|
||||||
|
"100.10.0.100 should be recognized as a local IP")
|
||||||
|
require.False(t, m.localipmanager.IsLocalIP(netip.MustParseAddr("172.17.0.2")),
|
||||||
|
"172.17.0.2 should not be recognized as a local IP")
|
||||||
|
|
||||||
|
pb := tc.packetBuilder()
|
||||||
|
|
||||||
|
trace, err := m.TracePacketFromBuilder(pb)
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
verifyTraceStages(t, trace, tc.expectedStages)
|
||||||
|
verifyFinalDisposition(t, trace, tc.expectedAllow)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
"strconv"
|
"strconv"
|
||||||
"strings"
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
"github.com/google/gopacket/layers"
|
"github.com/google/gopacket/layers"
|
||||||
@@ -22,6 +23,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/forwarder"
|
||||||
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
nblog "github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -65,9 +67,9 @@ func (r RouteRules) Sort() {
|
|||||||
// Manager userspace firewall manager
|
// Manager userspace firewall manager
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
// outgoingRules is used for hooks only
|
// outgoingRules is used for hooks only
|
||||||
outgoingRules map[string]RuleSet
|
outgoingRules map[netip.Addr]RuleSet
|
||||||
// incomingRules is used for filtering and hooks
|
// incomingRules is used for filtering and hooks
|
||||||
incomingRules map[string]RuleSet
|
incomingRules map[netip.Addr]RuleSet
|
||||||
routeRules RouteRules
|
routeRules RouteRules
|
||||||
wgNetwork *net.IPNet
|
wgNetwork *net.IPNet
|
||||||
decoders sync.Pool
|
decoders sync.Pool
|
||||||
@@ -79,9 +81,9 @@ type Manager struct {
|
|||||||
// indicates whether server routes are disabled
|
// indicates whether server routes are disabled
|
||||||
disableServerRoutes bool
|
disableServerRoutes bool
|
||||||
// indicates whether we forward packets not destined for ourselves
|
// indicates whether we forward packets not destined for ourselves
|
||||||
routingEnabled bool
|
routingEnabled atomic.Bool
|
||||||
// indicates whether we leave forwarding and filtering to the native firewall
|
// indicates whether we leave forwarding and filtering to the native firewall
|
||||||
nativeRouter bool
|
nativeRouter atomic.Bool
|
||||||
// indicates whether we track outbound connections
|
// indicates whether we track outbound connections
|
||||||
stateful bool
|
stateful bool
|
||||||
// indicates whether wireguards runs in netstack mode
|
// indicates whether wireguards runs in netstack mode
|
||||||
@@ -94,8 +96,9 @@ type Manager struct {
|
|||||||
udpTracker *conntrack.UDPTracker
|
udpTracker *conntrack.UDPTracker
|
||||||
icmpTracker *conntrack.ICMPTracker
|
icmpTracker *conntrack.ICMPTracker
|
||||||
tcpTracker *conntrack.TCPTracker
|
tcpTracker *conntrack.TCPTracker
|
||||||
forwarder *forwarder.Forwarder
|
forwarder atomic.Pointer[forwarder.Forwarder]
|
||||||
logger *nblog.Logger
|
logger *nblog.Logger
|
||||||
|
flowLogger nftypes.FlowLogger
|
||||||
}
|
}
|
||||||
|
|
||||||
// decoder for packages
|
// decoder for packages
|
||||||
@@ -112,16 +115,16 @@ type decoder struct {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create userspace firewall manager constructor
|
// Create userspace firewall manager constructor
|
||||||
func Create(iface common.IFaceMapper, disableServerRoutes bool) (*Manager, error) {
|
func Create(iface common.IFaceMapper, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
return create(iface, nil, disableServerRoutes)
|
return create(iface, nil, disableServerRoutes, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func CreateWithNativeFirewall(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
if nativeFirewall == nil {
|
if nativeFirewall == nil {
|
||||||
return nil, errors.New("native firewall is nil")
|
return nil, errors.New("native firewall is nil")
|
||||||
}
|
}
|
||||||
|
|
||||||
mgr, err := create(iface, nativeFirewall, disableServerRoutes)
|
mgr, err := create(iface, nativeFirewall, disableServerRoutes, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
@@ -148,7 +151,7 @@ func parseCreateEnv() (bool, bool) {
|
|||||||
return disableConntrack, enableLocalForwarding
|
return disableConntrack, enableLocalForwarding
|
||||||
}
|
}
|
||||||
|
|
||||||
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool) (*Manager, error) {
|
func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableServerRoutes bool, flowLogger nftypes.FlowLogger) (*Manager, error) {
|
||||||
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
disableConntrack, enableLocalForwarding := parseCreateEnv()
|
||||||
|
|
||||||
m := &Manager{
|
m := &Manager{
|
||||||
@@ -166,17 +169,18 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
},
|
},
|
||||||
},
|
},
|
||||||
nativeFirewall: nativeFirewall,
|
nativeFirewall: nativeFirewall,
|
||||||
outgoingRules: make(map[string]RuleSet),
|
outgoingRules: make(map[netip.Addr]RuleSet),
|
||||||
incomingRules: make(map[string]RuleSet),
|
incomingRules: make(map[netip.Addr]RuleSet),
|
||||||
wgIface: iface,
|
wgIface: iface,
|
||||||
localipmanager: newLocalIPManager(),
|
localipmanager: newLocalIPManager(),
|
||||||
disableServerRoutes: disableServerRoutes,
|
disableServerRoutes: disableServerRoutes,
|
||||||
routingEnabled: false,
|
|
||||||
stateful: !disableConntrack,
|
stateful: !disableConntrack,
|
||||||
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
logger: nblog.NewFromLogrus(log.StandardLogger()),
|
||||||
|
flowLogger: flowLogger,
|
||||||
netstack: netstack.IsEnabled(),
|
netstack: netstack.IsEnabled(),
|
||||||
localForwarding: enableLocalForwarding,
|
localForwarding: enableLocalForwarding,
|
||||||
}
|
}
|
||||||
|
m.routingEnabled.Store(false)
|
||||||
|
|
||||||
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
if err := m.localipmanager.UpdateLocalIPs(iface); err != nil {
|
||||||
return nil, fmt.Errorf("update local IPs: %w", err)
|
return nil, fmt.Errorf("update local IPs: %w", err)
|
||||||
@@ -185,9 +189,9 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
if disableConntrack {
|
if disableConntrack {
|
||||||
log.Info("conntrack is disabled")
|
log.Info("conntrack is disabled")
|
||||||
} else {
|
} else {
|
||||||
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger)
|
m.udpTracker = conntrack.NewUDPTracker(conntrack.DefaultUDPTimeout, m.logger, flowLogger)
|
||||||
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger)
|
m.icmpTracker = conntrack.NewICMPTracker(conntrack.DefaultICMPTimeout, m.logger, flowLogger)
|
||||||
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger)
|
m.tcpTracker = conntrack.NewTCPTracker(conntrack.DefaultTCPTimeout, m.logger, flowLogger)
|
||||||
}
|
}
|
||||||
|
|
||||||
// netstack needs the forwarder for local traffic
|
// netstack needs the forwarder for local traffic
|
||||||
@@ -208,7 +212,7 @@ func create(iface common.IFaceMapper, nativeFirewall firewall.Manager, disableSe
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
||||||
if m.forwarder == nil {
|
if m.forwarder.Load() == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
wgPrefix, err := netip.ParsePrefix(iface.Address().Network.String())
|
||||||
@@ -218,6 +222,7 @@ func (m *Manager) blockInvalidRouted(iface common.IFaceMapper) error {
|
|||||||
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
log.Debugf("blocking invalid routed traffic for %s", wgPrefix)
|
||||||
|
|
||||||
if _, err := m.AddRouteFiltering(
|
if _, err := m.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
[]netip.Prefix{netip.PrefixFrom(netip.IPv4Unspecified(), 0)},
|
||||||
wgPrefix,
|
wgPrefix,
|
||||||
firewall.ProtocolALL,
|
firewall.ProtocolALL,
|
||||||
@@ -251,20 +256,20 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
switch {
|
switch {
|
||||||
case disableUspRouting:
|
case disableUspRouting:
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
log.Info("userspace routing is disabled")
|
log.Info("userspace routing is disabled")
|
||||||
|
|
||||||
case m.disableServerRoutes:
|
case m.disableServerRoutes:
|
||||||
// if server routes are disabled we will let packets pass to the native stack
|
// if server routes are disabled we will let packets pass to the native stack
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("server routes are disabled")
|
log.Info("server routes are disabled")
|
||||||
|
|
||||||
case forceUserspaceRouter:
|
case forceUserspaceRouter:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing is forced")
|
log.Info("userspace routing is forced")
|
||||||
|
|
||||||
@@ -272,19 +277,19 @@ func (m *Manager) determineRouting() error {
|
|||||||
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
// if the OS supports routing natively, then we don't need to filter/route ourselves
|
||||||
// netstack mode won't support native routing as there is no interface
|
// netstack mode won't support native routing as there is no interface
|
||||||
|
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = true
|
m.nativeRouter.Store(true)
|
||||||
|
|
||||||
log.Info("native routing is enabled")
|
log.Info("native routing is enabled")
|
||||||
|
|
||||||
default:
|
default:
|
||||||
m.routingEnabled = true
|
m.routingEnabled.Store(true)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
log.Info("userspace routing enabled by default")
|
log.Info("userspace routing enabled by default")
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.routingEnabled && !m.nativeRouter {
|
if m.routingEnabled.Load() && !m.nativeRouter.Load() {
|
||||||
return m.initForwarder()
|
return m.initForwarder()
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -293,24 +298,24 @@ func (m *Manager) determineRouting() error {
|
|||||||
|
|
||||||
// initForwarder initializes the forwarder, it disables routing on errors
|
// initForwarder initializes the forwarder, it disables routing on errors
|
||||||
func (m *Manager) initForwarder() error {
|
func (m *Manager) initForwarder() error {
|
||||||
if m.forwarder != nil {
|
if m.forwarder.Load() != nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
// Only supported in userspace mode as we need to inject packets back into wireguard directly
|
||||||
intf := m.wgIface.GetWGDevice()
|
intf := m.wgIface.GetWGDevice()
|
||||||
if intf == nil {
|
if intf == nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return errors.New("forwarding not supported")
|
return errors.New("forwarding not supported")
|
||||||
}
|
}
|
||||||
|
|
||||||
forwarder, err := forwarder.New(m.wgIface, m.logger, m.netstack)
|
forwarder, err := forwarder.New(m.wgIface, m.logger, m.flowLogger, m.netstack)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
return fmt.Errorf("create forwarder: %w", err)
|
return fmt.Errorf("create forwarder: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder = forwarder
|
m.forwarder.Store(forwarder)
|
||||||
|
|
||||||
log.Debug("forwarder initialized")
|
log.Debug("forwarder initialized")
|
||||||
|
|
||||||
@@ -326,7 +331,7 @@ func (m *Manager) IsServerRouteSupported() bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddNatRule(pair)
|
return m.nativeFirewall.AddNatRule(pair)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -337,7 +342,7 @@ func (m *Manager) AddNatRule(pair firewall.RouterPair) error {
|
|||||||
|
|
||||||
// RemoveNatRule removes a routing firewall rule
|
// RemoveNatRule removes a routing firewall rule
|
||||||
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.RemoveNatRule(pair)
|
return m.nativeFirewall.RemoveNatRule(pair)
|
||||||
}
|
}
|
||||||
return nil
|
return nil
|
||||||
@@ -348,25 +353,31 @@ func (m *Manager) RemoveNatRule(pair firewall.RouterPair) error {
|
|||||||
// If comment argument is empty firewall manager should set
|
// If comment argument is empty firewall manager should set
|
||||||
// rule ID as comment for the rule
|
// rule ID as comment for the rule
|
||||||
func (m *Manager) AddPeerFiltering(
|
func (m *Manager) AddPeerFiltering(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
sPort *firewall.Port,
|
sPort *firewall.Port,
|
||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
_ string,
|
_ string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
|
// TODO: fix in upper layers
|
||||||
|
i, ok := netip.AddrFromSlice(ip)
|
||||||
|
if !ok {
|
||||||
|
return nil, fmt.Errorf("invalid IP: %s", ip)
|
||||||
|
}
|
||||||
|
|
||||||
|
i = i.Unmap()
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
mgmtId: id,
|
||||||
|
ip: i,
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
matchByIP: true,
|
matchByIP: true,
|
||||||
drop: action == firewall.ActionDrop,
|
drop: action == firewall.ActionDrop,
|
||||||
comment: comment,
|
|
||||||
}
|
}
|
||||||
if ipNormalized := ip.To4(); ipNormalized != nil {
|
if i.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
r.ip = ipNormalized
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
if s := r.ip.String(); s == "0.0.0.0" || s == "::" {
|
||||||
@@ -391,15 +402,16 @@ func (m *Manager) AddPeerFiltering(
|
|||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(RuleSet)
|
m.incomingRules[r.ip] = make(RuleSet)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
return []firewall.Rule{&r}, nil
|
return []firewall.Rule{&r}, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) AddRouteFiltering(
|
func (m *Manager) AddRouteFiltering(
|
||||||
|
id []byte,
|
||||||
sources []netip.Prefix,
|
sources []netip.Prefix,
|
||||||
destination netip.Prefix,
|
destination netip.Prefix,
|
||||||
proto firewall.Protocol,
|
proto firewall.Protocol,
|
||||||
@@ -407,16 +419,15 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
dPort *firewall.Port,
|
dPort *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
) (firewall.Rule, error) {
|
) (firewall.Rule, error) {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.AddRouteFiltering(sources, destination, proto, sPort, dPort, action)
|
return m.nativeFirewall.AddRouteFiltering(id, sources, destination, proto, sPort, dPort, action)
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
|
||||||
defer m.mutex.Unlock()
|
|
||||||
|
|
||||||
ruleID := uuid.New().String()
|
ruleID := uuid.New().String()
|
||||||
rule := RouteRule{
|
rule := RouteRule{
|
||||||
|
// TODO: consolidate these IDs
|
||||||
id: ruleID,
|
id: ruleID,
|
||||||
|
mgmtId: id,
|
||||||
sources: sources,
|
sources: sources,
|
||||||
destination: destination,
|
destination: destination,
|
||||||
proto: proto,
|
proto: proto,
|
||||||
@@ -425,14 +436,16 @@ func (m *Manager) AddRouteFiltering(
|
|||||||
action: action,
|
action: action,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
m.mutex.Lock()
|
||||||
m.routeRules = append(m.routeRules, rule)
|
m.routeRules = append(m.routeRules, rule)
|
||||||
m.routeRules.Sort()
|
m.routeRules.Sort()
|
||||||
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return &rule, nil
|
return &rule, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
func (m *Manager) DeleteRouteRule(rule firewall.Rule) error {
|
||||||
if m.nativeRouter && m.nativeFirewall != nil {
|
if m.nativeRouter.Load() && m.nativeFirewall != nil {
|
||||||
return m.nativeFirewall.DeleteRouteRule(rule)
|
return m.nativeFirewall.DeleteRouteRule(rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -461,10 +474,10 @@ func (m *Manager) DeletePeerRule(rule firewall.Rule) error {
|
|||||||
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
return fmt.Errorf("delete rule: invalid rule type: %T", rule)
|
||||||
}
|
}
|
||||||
|
|
||||||
if _, ok := m.incomingRules[r.ip.String()][r.id]; !ok {
|
if _, ok := m.incomingRules[r.ip][r.id]; !ok {
|
||||||
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
return fmt.Errorf("delete rule: no rule with such id: %v", r.id)
|
||||||
}
|
}
|
||||||
delete(m.incomingRules[r.ip.String()], r.id)
|
delete(m.incomingRules[r.ip], r.id)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -497,13 +510,13 @@ func (m *Manager) DeleteDNATRule(rule firewall.Rule) error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing filter outgoing packets
|
// DropOutgoing filter outgoing packets
|
||||||
func (m *Manager) DropOutgoing(packetData []byte) bool {
|
func (m *Manager) DropOutgoing(packetData []byte, size int) bool {
|
||||||
return m.processOutgoingHooks(packetData)
|
return m.processOutgoingHooks(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming filter incoming packets
|
// DropIncoming filter incoming packets
|
||||||
func (m *Manager) DropIncoming(packetData []byte) bool {
|
func (m *Manager) DropIncoming(packetData []byte, size int) bool {
|
||||||
return m.dropFilter(packetData)
|
return m.dropFilter(packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// UpdateLocalIPs updates the list of local IPs
|
// UpdateLocalIPs updates the list of local IPs
|
||||||
@@ -511,10 +524,7 @@ func (m *Manager) UpdateLocalIPs() error {
|
|||||||
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
return m.localipmanager.UpdateLocalIPs(m.wgIface)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
func (m *Manager) processOutgoingHooks(packetData []byte, size int) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
@@ -527,52 +537,37 @@ func (m *Manager) processOutgoingHooks(packetData []byte) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track all protocols if stateful mode is enabled
|
if d.decoded[1] == layers.LayerTypeUDP && m.udpHooksDrop(uint16(d.udp.DstPort), dstIP, packetData) {
|
||||||
if m.stateful {
|
return true
|
||||||
switch d.decoded[1] {
|
|
||||||
case layers.LayerTypeUDP:
|
|
||||||
m.trackUDPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeTCP:
|
|
||||||
m.trackTCPOutbound(d, srcIP, dstIP)
|
|
||||||
case layers.LayerTypeICMPv4:
|
|
||||||
m.trackICMPOutbound(d, srcIP, dstIP)
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Process UDP hooks even if stateful mode is disabled
|
if m.stateful {
|
||||||
if d.decoded[1] == layers.LayerTypeUDP {
|
m.trackOutbound(d, srcIP, dstIP, size)
|
||||||
return m.checkUDPHooks(d, dstIP, packetData)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP net.IP) {
|
func (m *Manager) extractIPs(d *decoder) (srcIP, dstIP netip.Addr) {
|
||||||
switch d.decoded[0] {
|
switch d.decoded[0] {
|
||||||
case layers.LayerTypeIPv4:
|
case layers.LayerTypeIPv4:
|
||||||
return d.ip4.SrcIP, d.ip4.DstIP
|
src, _ := netip.AddrFromSlice(d.ip4.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip4.DstIP)
|
||||||
|
return src, dst
|
||||||
case layers.LayerTypeIPv6:
|
case layers.LayerTypeIPv6:
|
||||||
return d.ip6.SrcIP, d.ip6.DstIP
|
src, _ := netip.AddrFromSlice(d.ip6.SrcIP)
|
||||||
|
dst, _ := netip.AddrFromSlice(d.ip6.DstIP)
|
||||||
|
return src, dst
|
||||||
default:
|
default:
|
||||||
return nil, nil
|
return netip.Addr{}, netip.Addr{}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackTCPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
|
||||||
flags := getTCPFlags(&d.tcp)
|
|
||||||
m.tcpTracker.TrackOutbound(
|
|
||||||
srcIP,
|
|
||||||
dstIP,
|
|
||||||
uint16(d.tcp.SrcPort),
|
|
||||||
uint16(d.tcp.DstPort),
|
|
||||||
flags,
|
|
||||||
)
|
|
||||||
}
|
|
||||||
|
|
||||||
func getTCPFlags(tcp *layers.TCP) uint8 {
|
func getTCPFlags(tcp *layers.TCP) uint8 {
|
||||||
var flags uint8
|
var flags uint8
|
||||||
if tcp.SYN {
|
if tcp.SYN {
|
||||||
@@ -596,100 +591,153 @@ func getTCPFlags(tcp *layers.TCP) uint8 {
|
|||||||
return flags
|
return flags
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) trackUDPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
func (m *Manager) trackOutbound(d *decoder, srcIP, dstIP netip.Addr, size int) {
|
||||||
m.udpTracker.TrackOutbound(
|
transport := d.decoded[1]
|
||||||
srcIP,
|
switch transport {
|
||||||
dstIP,
|
case layers.LayerTypeUDP:
|
||||||
uint16(d.udp.SrcPort),
|
m.udpTracker.TrackOutbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), size)
|
||||||
uint16(d.udp.DstPort),
|
case layers.LayerTypeTCP:
|
||||||
)
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackOutbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, size)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.icmpTracker.TrackOutbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, size)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) checkUDPHooks(d *decoder, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) trackInbound(d *decoder, srcIP, dstIP netip.Addr, ruleID []byte, size int) {
|
||||||
for _, ipKey := range []string{dstIP.String(), "0.0.0.0", "::"} {
|
transport := d.decoded[1]
|
||||||
if rules, exists := m.outgoingRules[ipKey]; exists {
|
switch transport {
|
||||||
|
case layers.LayerTypeUDP:
|
||||||
|
m.udpTracker.TrackInbound(srcIP, dstIP, uint16(d.udp.SrcPort), uint16(d.udp.DstPort), ruleID, size)
|
||||||
|
case layers.LayerTypeTCP:
|
||||||
|
flags := getTCPFlags(&d.tcp)
|
||||||
|
m.tcpTracker.TrackInbound(srcIP, dstIP, uint16(d.tcp.SrcPort), uint16(d.tcp.DstPort), flags, ruleID, size)
|
||||||
|
case layers.LayerTypeICMPv4:
|
||||||
|
m.icmpTracker.TrackInbound(srcIP, dstIP, d.icmp4.Id, d.icmp4.TypeCode, ruleID, size)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// udpHooksDrop checks if any UDP hooks should drop the packet
|
||||||
|
func (m *Manager) udpHooksDrop(dport uint16, dstIP netip.Addr, packetData []byte) bool {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
|
// Check specific destination IP first
|
||||||
|
if rules, exists := m.outgoingRules[dstIP]; exists {
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.udpHook != nil && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
return rule.udpHook(packetData)
|
return rule.udpHook(packetData)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (m *Manager) trackICMPOutbound(d *decoder, srcIP, dstIP net.IP) {
|
// Check IPv4 unspecified address
|
||||||
if d.icmp4.TypeCode.Type() == layers.ICMPv4TypeEchoRequest {
|
if rules, exists := m.outgoingRules[netip.IPv4Unspecified()]; exists {
|
||||||
m.icmpTracker.TrackOutbound(
|
for _, rule := range rules {
|
||||||
srcIP,
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
dstIP,
|
return rule.udpHook(packetData)
|
||||||
d.icmp4.Id,
|
|
||||||
d.icmp4.Seq,
|
|
||||||
)
|
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Check IPv6 unspecified address
|
||||||
|
if rules, exists := m.outgoingRules[netip.IPv6Unspecified()]; exists {
|
||||||
|
for _, rule := range rules {
|
||||||
|
if rule.udpHook != nil && portsMatch(rule.dPort, dport) {
|
||||||
|
return rule.udpHook(packetData)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
// dropFilter implements filtering logic for incoming packets.
|
// dropFilter implements filtering logic for incoming packets.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) dropFilter(packetData []byte) bool {
|
func (m *Manager) dropFilter(packetData []byte, size int) bool {
|
||||||
m.mutex.RLock()
|
|
||||||
defer m.mutex.RUnlock()
|
|
||||||
|
|
||||||
d := m.decoders.Get().(*decoder)
|
d := m.decoders.Get().(*decoder)
|
||||||
defer m.decoders.Put(d)
|
defer m.decoders.Put(d)
|
||||||
|
|
||||||
if !m.isValidPacket(d, packetData) {
|
valid, fragment := m.isValidPacket(d, packetData)
|
||||||
|
if !valid {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
srcIP, dstIP := m.extractIPs(d)
|
srcIP, dstIP := m.extractIPs(d)
|
||||||
if srcIP == nil {
|
if !srcIP.IsValid() {
|
||||||
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
m.logger.Error("Unknown network layer: %v", d.decoded[0])
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// TODO: pass fragments of routed packets to forwarder
|
||||||
|
if fragment {
|
||||||
|
m.logger.Trace("packet is a fragment: src=%v dst=%v id=%v flags=%v",
|
||||||
|
srcIP, dstIP, d.ip4.Id, d.ip4.Flags)
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
// For all inbound traffic, first check if it matches a tracked connection.
|
// For all inbound traffic, first check if it matches a tracked connection.
|
||||||
// This must happen before any other filtering because the packets are statefully tracked.
|
// This must happen before any other filtering because the packets are statefully tracked.
|
||||||
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP) {
|
if m.stateful && m.isValidTrackedConnection(d, srcIP, dstIP, size) {
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.localipmanager.IsLocalIP(dstIP) {
|
if m.localipmanager.IsLocalIP(dstIP) {
|
||||||
return m.handleLocalTraffic(d, srcIP, dstIP, packetData)
|
return m.handleLocalTraffic(d, srcIP, dstIP, packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData)
|
return m.handleRoutedTraffic(d, srcIP, dstIP, packetData, size)
|
||||||
}
|
}
|
||||||
|
|
||||||
// handleLocalTraffic handles local traffic.
|
// handleLocalTraffic handles local traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleLocalTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||||
if m.peerACLsBlock(srcIP, packetData, m.incomingRules, d) {
|
ruleID, blocked := m.peerACLsBlock(srcIP, packetData, m.incomingRules, d)
|
||||||
m.logger.Trace("Dropping local packet (ACL denied): src=%s dst=%s",
|
if blocked {
|
||||||
srcIP, dstIP)
|
_, pnum := getProtocolFromPacket(d)
|
||||||
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
|
m.logger.Trace("Dropping local packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeDrop,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: pnum,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
// TODO: icmp type/code
|
||||||
|
RxPackets: 1,
|
||||||
|
RxBytes: uint64(size),
|
||||||
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// if running in netstack mode we need to pass this to the forwarder
|
// if running in netstack mode we need to pass this to the forwarder
|
||||||
if m.netstack {
|
if m.netstack && m.localForwarding {
|
||||||
return m.handleNetstackLocalTraffic(packetData)
|
return m.handleNetstackLocalTraffic(packetData)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// track inbound packets to get the correct direction and session id for flows
|
||||||
|
m.trackInbound(d, srcIP, dstIP, ruleID, size)
|
||||||
|
|
||||||
|
// pass to either native or virtual stack (to be picked up by listeners)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
||||||
if !m.localForwarding {
|
|
||||||
// pass to virtual tcp/ip stack to be picked up by listeners
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
if m.forwarder == nil {
|
fwd := m.forwarder.Load()
|
||||||
|
if fwd == nil {
|
||||||
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
m.logger.Trace("Dropping local packet (forwarder not initialized)")
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||||
m.logger.Error("Failed to inject local packet: %v", err)
|
m.logger.Error("Failed to inject local packet: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -699,47 +747,68 @@ func (m *Manager) handleNetstackLocalTraffic(packetData []byte) bool {
|
|||||||
|
|
||||||
// handleRoutedTraffic handles routed traffic.
|
// handleRoutedTraffic handles routed traffic.
|
||||||
// If it returns true, the packet should be dropped.
|
// If it returns true, the packet should be dropped.
|
||||||
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP net.IP, packetData []byte) bool {
|
func (m *Manager) handleRoutedTraffic(d *decoder, srcIP, dstIP netip.Addr, packetData []byte, size int) bool {
|
||||||
// Drop if routing is disabled
|
// Drop if routing is disabled
|
||||||
if !m.routingEnabled {
|
if !m.routingEnabled.Load() {
|
||||||
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
m.logger.Trace("Dropping routed packet (routing disabled): src=%s dst=%s",
|
||||||
srcIP, dstIP)
|
srcIP, dstIP)
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pass to native stack if native router is enabled or forced
|
// Pass to native stack if native router is enabled or forced
|
||||||
if m.nativeRouter {
|
if m.nativeRouter.Load() {
|
||||||
|
m.trackInbound(d, srcIP, dstIP, nil, size)
|
||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
proto := getProtocolFromPacket(d)
|
proto, pnum := getProtocolFromPacket(d)
|
||||||
srcPort, dstPort := getPortsFromPacket(d)
|
srcPort, dstPort := getPortsFromPacket(d)
|
||||||
|
|
||||||
if !m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort) {
|
if ruleID, pass := m.routeACLsPass(srcIP, dstIP, proto, srcPort, dstPort); !pass {
|
||||||
m.logger.Trace("Dropping routed packet (ACL denied): src=%s:%d dst=%s:%d proto=%v",
|
m.logger.Trace("Dropping routed packet (ACL denied): rule_id=%s proto=%v src=%s:%d dst=%s:%d",
|
||||||
srcIP, srcPort, dstIP, dstPort, proto)
|
ruleID, pnum, srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
|
m.flowLogger.StoreEvent(nftypes.EventFields{
|
||||||
|
FlowID: uuid.New(),
|
||||||
|
Type: nftypes.TypeDrop,
|
||||||
|
RuleID: ruleID,
|
||||||
|
Direction: nftypes.Ingress,
|
||||||
|
Protocol: pnum,
|
||||||
|
SourceIP: srcIP,
|
||||||
|
DestIP: dstIP,
|
||||||
|
SourcePort: srcPort,
|
||||||
|
DestPort: dstPort,
|
||||||
|
// TODO: icmp type/code
|
||||||
|
RxPackets: 1,
|
||||||
|
RxBytes: uint64(size),
|
||||||
|
})
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
// Let forwarder handle the packet if it passed route ACLs
|
// Let forwarder handle the packet if it passed route ACLs
|
||||||
if err := m.forwarder.InjectIncomingPacket(packetData); err != nil {
|
fwd := m.forwarder.Load()
|
||||||
m.logger.Error("Failed to inject incoming packet: %v", err)
|
if fwd == nil {
|
||||||
|
m.logger.Trace("failed to forward routed packet (forwarder not initialized)")
|
||||||
|
} else {
|
||||||
|
if err := fwd.InjectIncomingPacket(packetData); err != nil {
|
||||||
|
m.logger.Error("Failed to inject routed packet: %v", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
// Forwarded packets shouldn't reach the native stack, hence they won't be visible in a packet capture
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
func getProtocolFromPacket(d *decoder) firewall.Protocol {
|
func getProtocolFromPacket(d *decoder) (firewall.Protocol, nftypes.Protocol) {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
return firewall.ProtocolTCP
|
return firewall.ProtocolTCP, nftypes.TCP
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
return firewall.ProtocolUDP
|
return firewall.ProtocolUDP, nftypes.UDP
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return firewall.ProtocolICMP
|
return firewall.ProtocolICMP, nftypes.ICMP
|
||||||
default:
|
default:
|
||||||
return firewall.ProtocolALL
|
return firewall.ProtocolALL, nftypes.ProtocolUnknown
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -754,20 +823,35 @@ func getPortsFromPacket(d *decoder) (srcPort, dstPort uint16) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidPacket(d *decoder, packetData []byte) bool {
|
// isValidPacket checks if the packet is valid.
|
||||||
|
// It returns true, false if the packet is valid and not a fragment.
|
||||||
|
// It returns true, true if the packet is a fragment and valid.
|
||||||
|
func (m *Manager) isValidPacket(d *decoder, packetData []byte) (bool, bool) {
|
||||||
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
if err := d.parser.DecodeLayers(packetData, &d.decoded); err != nil {
|
||||||
m.logger.Trace("couldn't decode packet, err: %s", err)
|
m.logger.Trace("couldn't decode packet, err: %s", err)
|
||||||
return false
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(d.decoded) < 2 {
|
l := len(d.decoded)
|
||||||
m.logger.Trace("packet doesn't have network and transport layers")
|
|
||||||
return false
|
// L3 and L4 are mandatory
|
||||||
|
if l >= 2 {
|
||||||
|
return true, false
|
||||||
}
|
}
|
||||||
return true
|
|
||||||
|
// Fragments are also valid
|
||||||
|
if l == 1 && d.decoded[0] == layers.LayerTypeIPv4 {
|
||||||
|
ip4 := d.ip4
|
||||||
|
if ip4.Flags&layers.IPv4MoreFragments != 0 || ip4.FragOffset != 0 {
|
||||||
|
return true, true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
m.logger.Trace("packet doesn't have network and transport layers")
|
||||||
|
return false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool {
|
func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP netip.Addr, size int) bool {
|
||||||
switch d.decoded[1] {
|
switch d.decoded[1] {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
return m.tcpTracker.IsValidInbound(
|
return m.tcpTracker.IsValidInbound(
|
||||||
@@ -776,6 +860,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
uint16(d.tcp.SrcPort),
|
uint16(d.tcp.SrcPort),
|
||||||
uint16(d.tcp.DstPort),
|
uint16(d.tcp.DstPort),
|
||||||
getTCPFlags(&d.tcp),
|
getTCPFlags(&d.tcp),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
@@ -784,6 +869,7 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
dstIP,
|
dstIP,
|
||||||
uint16(d.udp.SrcPort),
|
uint16(d.udp.SrcPort),
|
||||||
uint16(d.udp.DstPort),
|
uint16(d.udp.DstPort),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
case layers.LayerTypeICMPv4:
|
case layers.LayerTypeICMPv4:
|
||||||
@@ -791,8 +877,8 @@ func (m *Manager) isValidTrackedConnection(d *decoder, srcIP, dstIP net.IP) bool
|
|||||||
srcIP,
|
srcIP,
|
||||||
dstIP,
|
dstIP,
|
||||||
d.icmp4.Id,
|
d.icmp4.Id,
|
||||||
d.icmp4.Seq,
|
|
||||||
d.icmp4.TypeCode.Type(),
|
d.icmp4.TypeCode.Type(),
|
||||||
|
size,
|
||||||
)
|
)
|
||||||
|
|
||||||
// TODO: ICMPv6
|
// TODO: ICMPv6
|
||||||
@@ -812,25 +898,27 @@ func (m *Manager) isSpecialICMP(d *decoder) bool {
|
|||||||
icmpType == layers.ICMPv4TypeTimeExceeded
|
icmpType == layers.ICMPv4TypeTimeExceeded
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) peerACLsBlock(srcIP net.IP, packetData []byte, rules map[string]RuleSet, d *decoder) bool {
|
func (m *Manager) peerACLsBlock(srcIP netip.Addr, packetData []byte, rules map[netip.Addr]RuleSet, d *decoder) ([]byte, bool) {
|
||||||
|
m.mutex.RLock()
|
||||||
|
defer m.mutex.RUnlock()
|
||||||
if m.isSpecialICMP(d) {
|
if m.isSpecialICMP(d) {
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules[srcIP.String()], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[srcIP], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules["0.0.0.0"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv4Unspecified()], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
if filter, ok := validateRule(srcIP, packetData, rules["::"], d); ok {
|
if mgmtId, filter, ok := validateRule(srcIP, packetData, rules[netip.IPv6Unspecified()], d); ok {
|
||||||
return filter
|
return mgmtId, filter
|
||||||
}
|
}
|
||||||
|
|
||||||
// Default policy: DROP ALL
|
// Default policy: DROP ALL
|
||||||
return true
|
return nil, true
|
||||||
}
|
}
|
||||||
|
|
||||||
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
||||||
@@ -850,15 +938,15 @@ func portsMatch(rulePort *firewall.Port, packetPort uint16) bool {
|
|||||||
return false
|
return false
|
||||||
}
|
}
|
||||||
|
|
||||||
func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *decoder) (bool, bool) {
|
func validateRule(ip netip.Addr, packetData []byte, rules map[string]PeerRule, d *decoder) ([]byte, bool, bool) {
|
||||||
payloadLayer := d.decoded[1]
|
payloadLayer := d.decoded[1]
|
||||||
for _, rule := range rules {
|
for _, rule := range rules {
|
||||||
if rule.matchByIP && !ip.Equal(rule.ip) {
|
if rule.matchByIP && ip.Compare(rule.ip) != 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
if rule.protoLayer == layerTypeAll {
|
if rule.protoLayer == layerTypeAll {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
|
|
||||||
if payloadLayer != rule.protoLayer {
|
if payloadLayer != rule.protoLayer {
|
||||||
@@ -868,39 +956,36 @@ func validateRule(ip net.IP, packetData []byte, rules map[string]PeerRule, d *de
|
|||||||
switch payloadLayer {
|
switch payloadLayer {
|
||||||
case layers.LayerTypeTCP:
|
case layers.LayerTypeTCP:
|
||||||
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.tcp.SrcPort)) && portsMatch(rule.dPort, uint16(d.tcp.DstPort)) {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeUDP:
|
case layers.LayerTypeUDP:
|
||||||
// if rule has UDP hook (and if we are here we match this rule)
|
// if rule has UDP hook (and if we are here we match this rule)
|
||||||
// we ignore rule.drop and call this hook
|
// we ignore rule.drop and call this hook
|
||||||
if rule.udpHook != nil {
|
if rule.udpHook != nil {
|
||||||
return rule.udpHook(packetData), true
|
return rule.mgmtId, rule.udpHook(packetData), true
|
||||||
}
|
}
|
||||||
|
|
||||||
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
if portsMatch(rule.sPort, uint16(d.udp.SrcPort)) && portsMatch(rule.dPort, uint16(d.udp.DstPort)) {
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
case layers.LayerTypeICMPv4, layers.LayerTypeICMPv6:
|
||||||
return rule.drop, true
|
return rule.mgmtId, rule.drop, true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false, false
|
return nil, false, false
|
||||||
}
|
}
|
||||||
|
|
||||||
// routeACLsPass returns treu if the packet is allowed by the route ACLs
|
// routeACLsPass returns true if the packet is allowed by the route ACLs
|
||||||
func (m *Manager) routeACLsPass(srcIP, dstIP net.IP, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) routeACLsPass(srcIP, dstIP netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) ([]byte, bool) {
|
||||||
m.mutex.RLock()
|
m.mutex.RLock()
|
||||||
defer m.mutex.RUnlock()
|
defer m.mutex.RUnlock()
|
||||||
|
|
||||||
srcAddr := netip.AddrFrom4([4]byte(srcIP.To4()))
|
|
||||||
dstAddr := netip.AddrFrom4([4]byte(dstIP.To4()))
|
|
||||||
|
|
||||||
for _, rule := range m.routeRules {
|
for _, rule := range m.routeRules {
|
||||||
if m.ruleMatches(rule, srcAddr, dstAddr, proto, srcPort, dstPort) {
|
if matches := m.ruleMatches(rule, srcIP, dstIP, proto, srcPort, dstPort); matches {
|
||||||
return rule.action == firewall.ActionAccept
|
return rule.mgmtId, rule.action == firewall.ActionAccept
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return false
|
return nil, false
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
func (m *Manager) ruleMatches(rule RouteRule, srcAddr, dstAddr netip.Addr, proto firewall.Protocol, srcPort, dstPort uint16) bool {
|
||||||
@@ -940,36 +1025,32 @@ func (m *Manager) SetNetwork(network *net.IPNet) {
|
|||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not
|
// Hook function returns flag which indicates should be the matched package dropped or not
|
||||||
func (m *Manager) AddUDPPacketHook(
|
func (m *Manager) AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string {
|
||||||
in bool, ip net.IP, dPort uint16, hook func([]byte) bool,
|
|
||||||
) string {
|
|
||||||
r := PeerRule{
|
r := PeerRule{
|
||||||
id: uuid.New().String(),
|
id: uuid.New().String(),
|
||||||
ip: ip,
|
ip: ip,
|
||||||
protoLayer: layers.LayerTypeUDP,
|
protoLayer: layers.LayerTypeUDP,
|
||||||
dPort: &firewall.Port{Values: []uint16{dPort}},
|
dPort: &firewall.Port{Values: []uint16{dPort}},
|
||||||
ipLayer: layers.LayerTypeIPv6,
|
ipLayer: layers.LayerTypeIPv6,
|
||||||
comment: fmt.Sprintf("UDP Hook direction: %v, ip:%v, dport:%d", in, ip, dPort),
|
|
||||||
udpHook: hook,
|
udpHook: hook,
|
||||||
}
|
}
|
||||||
|
|
||||||
if ip.To4() != nil {
|
if ip.Is4() {
|
||||||
r.ipLayer = layers.LayerTypeIPv4
|
r.ipLayer = layers.LayerTypeIPv4
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
if in {
|
if in {
|
||||||
if _, ok := m.incomingRules[r.ip.String()]; !ok {
|
if _, ok := m.incomingRules[r.ip]; !ok {
|
||||||
m.incomingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.incomingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.incomingRules[r.ip.String()][r.id] = r
|
m.incomingRules[r.ip][r.id] = r
|
||||||
} else {
|
} else {
|
||||||
if _, ok := m.outgoingRules[r.ip.String()]; !ok {
|
if _, ok := m.outgoingRules[r.ip]; !ok {
|
||||||
m.outgoingRules[r.ip.String()] = make(map[string]PeerRule)
|
m.outgoingRules[r.ip] = make(map[string]PeerRule)
|
||||||
}
|
}
|
||||||
m.outgoingRules[r.ip.String()][r.id] = r
|
m.outgoingRules[r.ip][r.id] = r
|
||||||
}
|
}
|
||||||
|
|
||||||
m.mutex.Unlock()
|
m.mutex.Unlock()
|
||||||
|
|
||||||
return r.id
|
return r.id
|
||||||
@@ -1017,20 +1098,21 @@ func (m *Manager) DisableRouting() error {
|
|||||||
m.mutex.Lock()
|
m.mutex.Lock()
|
||||||
defer m.mutex.Unlock()
|
defer m.mutex.Unlock()
|
||||||
|
|
||||||
if m.forwarder == nil {
|
fwder := m.forwarder.Load()
|
||||||
|
if fwder == nil {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.routingEnabled = false
|
m.routingEnabled.Store(false)
|
||||||
m.nativeRouter = false
|
m.nativeRouter.Store(false)
|
||||||
|
|
||||||
// don't stop forwarder if in use by netstack
|
// don't stop forwarder if in use by netstack
|
||||||
if m.netstack && m.localForwarding {
|
if m.netstack && m.localForwarding {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
m.forwarder.Stop()
|
fwder.Stop()
|
||||||
m.forwarder = nil
|
m.forwarder.Store(nil)
|
||||||
|
|
||||||
log.Debug("forwarder stopped")
|
log.Debug("forwarder stopped")
|
||||||
|
|
||||||
|
|||||||
@@ -93,8 +93,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: false,
|
stateful: false,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Single rule allowing all traffic
|
// Single rule allowing all traffic
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil,
|
_, err := m.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolALL, nil, nil, fw.ActionAccept, "")
|
||||||
fw.ActionAccept, "", "allow all")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
desc: "Baseline: Single 'allow all' rule without connection tracking",
|
||||||
@@ -114,10 +113,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Add explicit rules matching return traffic pattern
|
// Add explicit rules matching return traffic pattern
|
||||||
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
for i := 0; i < 1000; i++ { // Simulate realistic ruleset size
|
||||||
ip := generateRandomIPs(1)[0]
|
ip := generateRandomIPs(1)[0]
|
||||||
_, err := m.AddPeerFiltering(ip, fw.ProtocolTCP,
|
_, err := m.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
|
ip,
|
||||||
|
fw.ProtocolTCP,
|
||||||
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
&fw.Port{Values: []uint16{uint16(1024 + i)}},
|
||||||
&fw.Port{Values: []uint16{80}},
|
&fw.Port{Values: []uint16{80}},
|
||||||
fw.ActionAccept, "", "explicit return")
|
fw.ActionAccept,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
},
|
},
|
||||||
@@ -128,8 +132,15 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
stateful: true,
|
stateful: true,
|
||||||
setupFunc: func(m *Manager) {
|
setupFunc: func(m *Manager) {
|
||||||
// Add some basic rules but rely on state for established connections
|
// Add some basic rules but rely on state for established connections
|
||||||
_, err := m.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP, nil, nil,
|
_, err := m.AddPeerFiltering(
|
||||||
fw.ActionDrop, "", "default drop")
|
nil,
|
||||||
|
net.ParseIP("0.0.0.0"),
|
||||||
|
fw.ProtocolTCP,
|
||||||
|
nil,
|
||||||
|
nil,
|
||||||
|
fw.ActionDrop,
|
||||||
|
"",
|
||||||
|
)
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
},
|
},
|
||||||
desc: "Connection tracking with established connections",
|
desc: "Connection tracking with established connections",
|
||||||
@@ -158,7 +169,7 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
// Create manager and basic setup
|
// Create manager and basic setup
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -182,13 +193,13 @@ func BenchmarkCoreFiltering(b *testing.B) {
|
|||||||
|
|
||||||
// For stateful scenarios, establish the connection
|
// For stateful scenarios, establish the connection
|
||||||
if sc.stateful {
|
if sc.stateful {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Measure inbound packet processing
|
// Measure inbound packet processing
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -203,7 +214,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
b.Run(fmt.Sprintf("conns_%d", count), func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -219,7 +230,7 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
for i := 0; i < count; i++ {
|
for i := 0; i < count; i++ {
|
||||||
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
outbound := generatePacket(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, layers.IPProtocolTCP)
|
uint16(1024+i), 80, layers.IPProtocolTCP)
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Test packet
|
// Test packet
|
||||||
@@ -227,11 +238,11 @@ func BenchmarkStateScaling(b *testing.B) {
|
|||||||
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
testIn := generatePacket(b, dstIPs[0], srcIPs[0], 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
// First establish our test connection
|
// First establish our test connection
|
||||||
manager.processOutgoingHooks(testOut)
|
manager.processOutgoingHooks(testOut, 0)
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(testIn)
|
manager.dropFilter(testIn, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -251,7 +262,7 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -267,12 +278,12 @@ func BenchmarkEstablishmentOverhead(b *testing.B) {
|
|||||||
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
inbound := generatePacket(b, dstIP, srcIP, 80, 1024, layers.IPProtocolTCP)
|
||||||
|
|
||||||
if sc.established {
|
if sc.established {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -450,7 +461,7 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
b.Run(sc.name, func(b *testing.B) {
|
b.Run(sc.name, func(b *testing.B) {
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
b.Cleanup(func() {
|
b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -466,25 +477,25 @@ func BenchmarkRoutedNetworkReturn(b *testing.B) {
|
|||||||
// For stateful cases and established connections
|
// For stateful cases and established connections
|
||||||
if !strings.Contains(sc.name, "allow_non_wg") ||
|
if !strings.Contains(sc.name, "allow_non_wg") ||
|
||||||
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
(strings.Contains(sc.state, "established") || sc.state == "post_handshake") {
|
||||||
manager.processOutgoingHooks(outbound)
|
manager.processOutgoingHooks(outbound, 0)
|
||||||
|
|
||||||
// For TCP post-handshake, simulate full handshake
|
// For TCP post-handshake, simulate full handshake
|
||||||
if sc.state == "post_handshake" {
|
if sc.state == "post_handshake" {
|
||||||
// SYN
|
// SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
syn := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
synack := generateTCPPacketWithFlags(b, dstIP, srcIP, 80, 1024, uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
ack := generateTCPPacketWithFlags(b, srcIP, dstIP, 1024, 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
manager.dropFilter(inbound)
|
manager.dropFilter(inbound, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -577,7 +588,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -590,10 +601,7 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -616,17 +624,17 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
// Initial SYN
|
// Initial SYN
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
// SYN-ACK
|
// SYN-ACK
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
// ACK
|
// ACK
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Prepare test packets simulating bidirectional traffic
|
// Prepare test packets simulating bidirectional traffic
|
||||||
@@ -647,9 +655,9 @@ func BenchmarkLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
// First outbound data
|
// First outbound data
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
// Then inbound response - this is what we're actually measuring
|
// Then inbound response - this is what we're actually measuring
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -668,7 +676,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -681,10 +689,7 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
// Single rule to allow all return traffic from port 80
|
// Single rule to allow all return traffic from port 80
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -756,19 +761,19 @@ func BenchmarkShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Connection establishment
|
// Connection establishment
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
// Data transfer
|
// Data transfer
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
// Connection teardown
|
// Connection teardown
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -787,7 +792,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -799,10 +804,7 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
// Setup initial state based on scenario
|
// Setup initial state based on scenario
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -824,15 +826,15 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
for i := 0; i < sc.connCount; i++ {
|
for i := 0; i < sc.connCount; i++ {
|
||||||
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
syn := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
uint16(1024+i), 80, uint16(conntrack.TCPSyn))
|
||||||
manager.processOutgoingHooks(syn)
|
manager.processOutgoingHooks(syn, 0)
|
||||||
|
|
||||||
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
synack := generateTCPPacketWithFlags(b, dstIPs[i], srcIPs[i],
|
||||||
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
80, uint16(1024+i), uint16(conntrack.TCPSyn|conntrack.TCPAck))
|
||||||
manager.dropFilter(synack)
|
manager.dropFilter(synack, 0)
|
||||||
|
|
||||||
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
ack := generateTCPPacketWithFlags(b, srcIPs[i], dstIPs[i],
|
||||||
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
uint16(1024+i), 80, uint16(conntrack.TCPAck))
|
||||||
manager.processOutgoingHooks(ack)
|
manager.processOutgoingHooks(ack, 0)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Pre-generate test packets
|
// Pre-generate test packets
|
||||||
@@ -854,8 +856,8 @@ func BenchmarkParallelLongLivedConnections(b *testing.B) {
|
|||||||
counter++
|
counter++
|
||||||
|
|
||||||
// Simulate bidirectional traffic
|
// Simulate bidirectional traffic
|
||||||
manager.processOutgoingHooks(outPackets[connIdx])
|
manager.processOutgoingHooks(outPackets[connIdx], 0)
|
||||||
manager.dropFilter(inPackets[connIdx])
|
manager.dropFilter(inPackets[connIdx], 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -875,7 +877,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
|
|
||||||
manager, _ := Create(&IFaceMock{
|
manager, _ := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
defer b.Cleanup(func() {
|
defer b.Cleanup(func() {
|
||||||
require.NoError(b, manager.Close(nil))
|
require.NoError(b, manager.Close(nil))
|
||||||
})
|
})
|
||||||
@@ -886,10 +888,7 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
if sc.rules {
|
if sc.rules {
|
||||||
_, err := manager.AddPeerFiltering(net.ParseIP("0.0.0.0"), fw.ProtocolTCP,
|
_, err := manager.AddPeerFiltering(nil, net.ParseIP("0.0.0.0"), fw.ProtocolTCP, &fw.Port{Values: []uint16{80}}, nil, fw.ActionAccept, "")
|
||||||
&fw.Port{Values: []uint16{80}},
|
|
||||||
nil,
|
|
||||||
fw.ActionAccept, "", "return traffic")
|
|
||||||
require.NoError(b, err)
|
require.NoError(b, err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -951,17 +950,17 @@ func BenchmarkParallelShortLivedConnections(b *testing.B) {
|
|||||||
p := patterns[connIdx]
|
p := patterns[connIdx]
|
||||||
|
|
||||||
// Full connection lifecycle
|
// Full connection lifecycle
|
||||||
manager.processOutgoingHooks(p.syn)
|
manager.processOutgoingHooks(p.syn, 0)
|
||||||
manager.dropFilter(p.synAck)
|
manager.dropFilter(p.synAck, 0)
|
||||||
manager.processOutgoingHooks(p.ack)
|
manager.processOutgoingHooks(p.ack, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.request)
|
manager.processOutgoingHooks(p.request, 0)
|
||||||
manager.dropFilter(p.response)
|
manager.dropFilter(p.response, 0)
|
||||||
|
|
||||||
manager.processOutgoingHooks(p.finClient)
|
manager.processOutgoingHooks(p.finClient, 0)
|
||||||
manager.dropFilter(p.ackServer)
|
manager.dropFilter(p.ackServer, 0)
|
||||||
manager.dropFilter(p.finServer)
|
manager.dropFilter(p.finServer, 0)
|
||||||
manager.processOutgoingHooks(p.ackClient)
|
manager.processOutgoingHooks(p.ackClient, 0)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
})
|
})
|
||||||
@@ -1033,14 +1032,7 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rules {
|
for _, r := range rules {
|
||||||
_, err := manager.AddRouteFiltering(
|
_, err := manager.AddRouteFiltering(nil, r.sources, r.dest, r.proto, nil, r.port, fw.ActionAccept)
|
||||||
r.sources,
|
|
||||||
r.dest,
|
|
||||||
r.proto,
|
|
||||||
nil,
|
|
||||||
r.port,
|
|
||||||
fw.ActionAccept,
|
|
||||||
)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
b.Fatal(err)
|
b.Fatal(err)
|
||||||
}
|
}
|
||||||
@@ -1062,8 +1054,8 @@ func BenchmarkRouteACLs(b *testing.B) {
|
|||||||
b.ResetTimer()
|
b.ResetTimer()
|
||||||
for i := 0; i < b.N; i++ {
|
for i := 0; i < b.N; i++ {
|
||||||
for _, tc := range cases {
|
for _, tc := range cases {
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
manager.routeACLsPass(srcIP, dstIP, tc.proto, 0, tc.dstPort)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,7 +34,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotNil(t, manager)
|
require.NotNil(t, manager)
|
||||||
|
|
||||||
@@ -192,20 +192,20 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
|
|
||||||
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
t.Run("Implicit DROP (no rules)", func(t *testing.T) {
|
||||||
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
packet := createTestPacket(t, "100.10.0.1", "100.10.0.100", fw.ProtocolTCP, 12345, 443)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
require.True(t, isDropped, "Packet should be dropped when no rules exist")
|
||||||
})
|
})
|
||||||
|
|
||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rules, err := manager.AddPeerFiltering(
|
rules, err := manager.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.ParseIP(tc.ruleIP),
|
net.ParseIP(tc.ruleIP),
|
||||||
tc.ruleProto,
|
tc.ruleProto,
|
||||||
tc.ruleSrcPort,
|
tc.ruleSrcPort,
|
||||||
tc.ruleDstPort,
|
tc.ruleDstPort,
|
||||||
tc.ruleAction,
|
tc.ruleAction,
|
||||||
"",
|
"",
|
||||||
tc.name,
|
|
||||||
)
|
)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
require.NotEmpty(t, rules)
|
require.NotEmpty(t, rules)
|
||||||
@@ -217,7 +217,7 @@ func TestPeerACLFiltering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
packet := createTestPacket(t, tc.srcIP, tc.dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
isDropped := manager.DropIncoming(packet)
|
isDropped := manager.DropIncoming(packet, 0)
|
||||||
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
require.Equal(t, tc.shouldBeBlocked, isDropped)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -302,12 +302,12 @@ func setupRoutedManager(tb testing.TB, network string) *Manager {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(tb, manager.EnableRouting())
|
require.NoError(tb, manager.EnableRouting())
|
||||||
require.NoError(tb, err)
|
require.NoError(tb, err)
|
||||||
require.NotNil(tb, manager)
|
require.NotNil(tb, manager)
|
||||||
require.True(tb, manager.routingEnabled)
|
require.True(tb, manager.routingEnabled.Load())
|
||||||
require.False(tb, manager.nativeRouter)
|
require.False(tb, manager.nativeRouter.Load())
|
||||||
|
|
||||||
tb.Cleanup(func() {
|
tb.Cleanup(func() {
|
||||||
require.NoError(tb, manager.Close(nil))
|
require.NoError(tb, manager.Close(nil))
|
||||||
@@ -803,6 +803,7 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
for _, tc := range testCases {
|
for _, tc := range testCases {
|
||||||
t.Run(tc.name, func(t *testing.T) {
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
tc.rule.sources,
|
tc.rule.sources,
|
||||||
tc.rule.dest,
|
tc.rule.dest,
|
||||||
tc.rule.proto,
|
tc.rule.proto,
|
||||||
@@ -817,12 +818,12 @@ func TestRouteACLFiltering(t *testing.T) {
|
|||||||
require.NoError(t, manager.DeleteRouteRule(rule))
|
require.NoError(t, manager.DeleteRouteRule(rule))
|
||||||
})
|
})
|
||||||
|
|
||||||
srcIP := net.ParseIP(tc.srcIP)
|
srcIP := netip.MustParseAddr(tc.srcIP)
|
||||||
dstIP := net.ParseIP(tc.dstIP)
|
dstIP := netip.MustParseAddr(tc.dstIP)
|
||||||
|
|
||||||
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
// testing routeACLsPass only and not DropIncoming, as routed packets are dropped after being passed
|
||||||
// to the forwarder
|
// to the forwarder
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, tc.proto, tc.srcPort, tc.dstPort)
|
||||||
require.Equal(t, tc.shouldPass, isAllowed)
|
require.Equal(t, tc.shouldPass, isAllowed)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -985,6 +986,7 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
var rules []fw.Rule
|
var rules []fw.Rule
|
||||||
for _, r := range tc.rules {
|
for _, r := range tc.rules {
|
||||||
rule, err := manager.AddRouteFiltering(
|
rule, err := manager.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
r.sources,
|
r.sources,
|
||||||
r.dest,
|
r.dest,
|
||||||
r.proto,
|
r.proto,
|
||||||
@@ -1004,10 +1006,10 @@ func TestRouteACLOrder(t *testing.T) {
|
|||||||
})
|
})
|
||||||
|
|
||||||
for i, p := range tc.packets {
|
for i, p := range tc.packets {
|
||||||
srcIP := net.ParseIP(p.srcIP)
|
srcIP := netip.MustParseAddr(p.srcIP)
|
||||||
dstIP := net.ParseIP(p.dstIP)
|
dstIP := netip.MustParseAddr(p.dstIP)
|
||||||
|
|
||||||
isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
_, isAllowed := manager.routeACLsPass(srcIP, dstIP, p.proto, p.srcPort, p.dstPort)
|
||||||
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
require.Equal(t, p.shouldPass, isAllowed, "packet %d failed", i)
|
||||||
}
|
}
|
||||||
})
|
})
|
||||||
|
|||||||
@@ -3,6 +3,7 @@ package uspfilter
|
|||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
@@ -18,9 +19,11 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
"github.com/netbirdio/netbird/client/firewall/uspfilter/log"
|
||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
)
|
)
|
||||||
|
|
||||||
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
var logger = log.NewFromLogrus(logrus.StandardLogger())
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type IFaceMock struct {
|
type IFaceMock struct {
|
||||||
SetFilterFunc func(device.PacketFilter) error
|
SetFilterFunc func(device.PacketFilter) error
|
||||||
@@ -62,7 +65,7 @@ func TestManagerCreate(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -82,7 +85,7 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -92,9 +95,8 @@ func TestManagerAddPeerFiltering(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
rule, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule, err := m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -116,26 +118,25 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ip := net.ParseIP("192.168.1.1")
|
ip := netip.MustParseAddr("192.168.1.1")
|
||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule 2"
|
|
||||||
|
|
||||||
rule2, err := m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
rule2, err := m.AddPeerFiltering(nil, ip.AsSlice(), proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; !ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; !ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -149,7 +150,7 @@ func TestManagerDeleteRule(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, r := range rule2 {
|
for _, r := range rule2 {
|
||||||
if _, ok := m.incomingRules[ip.String()][r.ID()]; ok {
|
if _, ok := m.incomingRules[ip][r.ID()]; ok {
|
||||||
t.Errorf("rule2 is not in the incomingRules")
|
t.Errorf("rule2 is not in the incomingRules")
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -160,7 +161,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name string
|
name string
|
||||||
in bool
|
in bool
|
||||||
expDir fw.RuleDirection
|
expDir fw.RuleDirection
|
||||||
ip net.IP
|
ip netip.Addr
|
||||||
dPort uint16
|
dPort uint16
|
||||||
hook func([]byte) bool
|
hook func([]byte) bool
|
||||||
expectedID string
|
expectedID string
|
||||||
@@ -169,7 +170,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Outgoing UDP Packet Hook",
|
name: "Test Outgoing UDP Packet Hook",
|
||||||
in: false,
|
in: false,
|
||||||
expDir: fw.RuleDirectionOUT,
|
expDir: fw.RuleDirectionOUT,
|
||||||
ip: net.IPv4(10, 168, 0, 1),
|
ip: netip.MustParseAddr("10.168.0.1"),
|
||||||
dPort: 8000,
|
dPort: 8000,
|
||||||
hook: func([]byte) bool { return true },
|
hook: func([]byte) bool { return true },
|
||||||
},
|
},
|
||||||
@@ -177,7 +178,7 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
name: "Test Incoming UDP Packet Hook",
|
name: "Test Incoming UDP Packet Hook",
|
||||||
in: true,
|
in: true,
|
||||||
expDir: fw.RuleDirectionIN,
|
expDir: fw.RuleDirectionIN,
|
||||||
ip: net.IPv6loopback,
|
ip: netip.MustParseAddr("::1"),
|
||||||
dPort: 9000,
|
dPort: 9000,
|
||||||
hook: func([]byte) bool { return false },
|
hook: func([]byte) bool { return false },
|
||||||
},
|
},
|
||||||
@@ -187,18 +188,18 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Run(tt.name, func(t *testing.T) {
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
manager.AddUDPPacketHook(tt.in, tt.ip, tt.dPort, tt.hook)
|
||||||
|
|
||||||
var addedRule PeerRule
|
var addedRule PeerRule
|
||||||
if tt.in {
|
if tt.in {
|
||||||
if len(manager.incomingRules[tt.ip.String()]) != 1 {
|
if len(manager.incomingRules[tt.ip]) != 1 {
|
||||||
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
t.Errorf("expected 1 incoming rule, got %d", len(manager.incomingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.incomingRules[tt.ip.String()] {
|
for _, rule := range manager.incomingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
@@ -206,12 +207,12 @@ func TestAddUDPPacketHook(t *testing.T) {
|
|||||||
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
t.Errorf("expected 1 outgoing rule, got %d", len(manager.outgoingRules))
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
for _, rule := range manager.outgoingRules[tt.ip.String()] {
|
for _, rule := range manager.outgoingRules[tt.ip] {
|
||||||
addedRule = rule
|
addedRule = rule
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if !tt.ip.Equal(addedRule.ip) {
|
if tt.ip.Compare(addedRule.ip) != 0 {
|
||||||
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
t.Errorf("expected ip %s, got %s", tt.ip, addedRule.ip)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -236,7 +237,7 @@ func TestManagerReset(t *testing.T) {
|
|||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -246,9 +247,8 @@ func TestManagerReset(t *testing.T) {
|
|||||||
proto := fw.ProtocolTCP
|
proto := fw.ProtocolTCP
|
||||||
port := &fw.Port{Values: []uint16{80}}
|
port := &fw.Port{Values: []uint16{80}}
|
||||||
action := fw.ActionDrop
|
action := fw.ActionDrop
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, port, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, port, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -279,7 +279,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
},
|
},
|
||||||
}
|
}
|
||||||
|
|
||||||
m, err := Create(ifaceMock, false)
|
m, err := Create(ifaceMock, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to create Manager: %v", err)
|
t.Errorf("failed to create Manager: %v", err)
|
||||||
return
|
return
|
||||||
@@ -292,9 +292,8 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
ip := net.ParseIP("0.0.0.0")
|
ip := net.ParseIP("0.0.0.0")
|
||||||
proto := fw.ProtocolUDP
|
proto := fw.ProtocolUDP
|
||||||
action := fw.ActionAccept
|
action := fw.ActionAccept
|
||||||
comment := "Test rule"
|
|
||||||
|
|
||||||
_, err = m.AddPeerFiltering(ip, proto, nil, nil, action, "", comment)
|
_, err = m.AddPeerFiltering(nil, ip, proto, nil, nil, action, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("failed to add filtering: %v", err)
|
t.Errorf("failed to add filtering: %v", err)
|
||||||
return
|
return
|
||||||
@@ -328,7 +327,7 @@ func TestNotMatchByIP(t *testing.T) {
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
if m.dropFilter(buf.Bytes()) {
|
if m.dropFilter(buf.Bytes(), 0) {
|
||||||
t.Errorf("expected packet to be accepted")
|
t.Errorf("expected packet to be accepted")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
@@ -347,7 +346,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// creating manager instance
|
// creating manager instance
|
||||||
manager, err := Create(iface, false)
|
manager, err := Create(iface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("Failed to create Manager: %s", err)
|
t.Fatalf("Failed to create Manager: %s", err)
|
||||||
}
|
}
|
||||||
@@ -357,7 +356,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
|
|
||||||
// Add a UDP packet hook
|
// Add a UDP packet hook
|
||||||
hookFunc := func(data []byte) bool { return true }
|
hookFunc := func(data []byte) bool { return true }
|
||||||
hookID := manager.AddUDPPacketHook(false, net.IPv4(192, 168, 0, 1), 8080, hookFunc)
|
hookID := manager.AddUDPPacketHook(false, netip.MustParseAddr("192.168.0.1"), 8080, hookFunc)
|
||||||
|
|
||||||
// Assert the hook is added by finding it in the manager's outgoing rules
|
// Assert the hook is added by finding it in the manager's outgoing rules
|
||||||
found := false
|
found := false
|
||||||
@@ -393,7 +392,7 @@ func TestRemovePacketHook(t *testing.T) {
|
|||||||
func TestProcessOutgoingHooks(t *testing.T) {
|
func TestProcessOutgoingHooks(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -401,7 +400,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
Mask: net.CIDRMask(16, 32),
|
Mask: net.CIDRMask(16, 32),
|
||||||
}
|
}
|
||||||
manager.udpTracker.Close()
|
manager.udpTracker.Close()
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(100*time.Millisecond, logger, flowLogger)
|
||||||
defer func() {
|
defer func() {
|
||||||
require.NoError(t, manager.Close(nil))
|
require.NoError(t, manager.Close(nil))
|
||||||
}()
|
}()
|
||||||
@@ -423,7 +422,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
hookCalled := false
|
hookCalled := false
|
||||||
hookID := manager.AddUDPPacketHook(
|
hookID := manager.AddUDPPacketHook(
|
||||||
false,
|
false,
|
||||||
net.ParseIP("100.10.0.100"),
|
netip.MustParseAddr("100.10.0.100"),
|
||||||
53,
|
53,
|
||||||
func([]byte) bool {
|
func([]byte) bool {
|
||||||
hookCalled = true
|
hookCalled = true
|
||||||
@@ -458,7 +457,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Test hook gets called
|
// Test hook gets called
|
||||||
result := manager.processOutgoingHooks(buf.Bytes())
|
result := manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.True(t, result)
|
require.True(t, result)
|
||||||
require.True(t, hookCalled)
|
require.True(t, hookCalled)
|
||||||
|
|
||||||
@@ -468,7 +467,7 @@ func TestProcessOutgoingHooks(t *testing.T) {
|
|||||||
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
err = gopacket.SerializeLayers(buf, opts, ipv4)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
result = manager.processOutgoingHooks(buf.Bytes())
|
result = manager.processOutgoingHooks(buf.Bytes(), 0)
|
||||||
require.False(t, result)
|
require.False(t, result)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -479,7 +478,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
ifaceMock := &IFaceMock{
|
ifaceMock := &IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}
|
}
|
||||||
manager, err := Create(ifaceMock, false)
|
manager, err := Create(ifaceMock, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
time.Sleep(time.Second)
|
time.Sleep(time.Second)
|
||||||
|
|
||||||
@@ -494,7 +493,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
start := time.Now()
|
start := time.Now()
|
||||||
for i := 0; i < testMax; i++ {
|
for i := 0; i < testMax; i++ {
|
||||||
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
port := &fw.Port{Values: []uint16{uint16(1000 + i)}}
|
||||||
_, err = manager.AddPeerFiltering(ip, "tcp", nil, port, fw.ActionAccept, "", "accept HTTP traffic")
|
_, err = manager.AddPeerFiltering(nil, ip, "tcp", nil, port, fw.ActionAccept, "")
|
||||||
|
|
||||||
require.NoError(t, err, "failed to add rule")
|
require.NoError(t, err, "failed to add rule")
|
||||||
}
|
}
|
||||||
@@ -506,7 +505,7 @@ func TestUSPFilterCreatePerformance(t *testing.T) {
|
|||||||
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
||||||
manager, err := Create(&IFaceMock{
|
manager, err := Create(&IFaceMock{
|
||||||
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
SetFilterFunc: func(device.PacketFilter) error { return nil },
|
||||||
}, false)
|
}, false, flowLogger)
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
manager.wgNetwork = &net.IPNet{
|
manager.wgNetwork = &net.IPNet{
|
||||||
@@ -515,7 +514,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
manager.udpTracker.Close() // Close the existing tracker
|
manager.udpTracker.Close() // Close the existing tracker
|
||||||
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger)
|
manager.udpTracker = conntrack.NewUDPTracker(200*time.Millisecond, logger, flowLogger)
|
||||||
manager.decoders = sync.Pool{
|
manager.decoders = sync.Pool{
|
||||||
New: func() any {
|
New: func() any {
|
||||||
d := &decoder{
|
d := &decoder{
|
||||||
@@ -534,8 +533,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
// Set up packet parameters
|
// Set up packet parameters
|
||||||
srcIP := net.ParseIP("100.10.0.1")
|
srcIP := netip.MustParseAddr("100.10.0.1")
|
||||||
dstIP := net.ParseIP("100.10.0.100")
|
dstIP := netip.MustParseAddr("100.10.0.100")
|
||||||
srcPort := uint16(51334)
|
srcPort := uint16(51334)
|
||||||
dstPort := uint16(53)
|
dstPort := uint16(53)
|
||||||
|
|
||||||
@@ -543,8 +542,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
outboundIPv4 := &layers.IPv4{
|
outboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: srcIP,
|
SrcIP: srcIP.AsSlice(),
|
||||||
DstIP: dstIP,
|
DstIP: dstIP.AsSlice(),
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
outboundUDP := &layers.UDP{
|
outboundUDP := &layers.UDP{
|
||||||
@@ -569,15 +568,15 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Process outbound packet and verify connection tracking
|
// Process outbound packet and verify connection tracking
|
||||||
drop := manager.DropOutgoing(outboundBuf.Bytes())
|
drop := manager.DropOutgoing(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Initial outbound packet should not be dropped")
|
require.False(t, drop, "Initial outbound packet should not be dropped")
|
||||||
|
|
||||||
// Verify connection was tracked
|
// Verify connection was tracked
|
||||||
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
conn, exists := manager.udpTracker.GetConnection(srcIP, srcPort, dstIP, dstPort)
|
||||||
|
|
||||||
require.True(t, exists, "Connection should be tracked after outbound packet")
|
require.True(t, exists, "Connection should be tracked after outbound packet")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(srcIP), conn.SourceIP), "Source IP should match")
|
require.True(t, srcIP.Compare(conn.SourceIP) == 0, "Source IP should match")
|
||||||
require.True(t, conntrack.ValidateIPs(conntrack.MakeIPAddr(dstIP), conn.DestIP), "Destination IP should match")
|
require.True(t, dstIP.Compare(conn.DestIP) == 0, "Destination IP should match")
|
||||||
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
require.Equal(t, srcPort, conn.SourcePort, "Source port should match")
|
||||||
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
require.Equal(t, dstPort, conn.DestPort, "Destination port should match")
|
||||||
|
|
||||||
@@ -585,8 +584,8 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
inboundIPv4 := &layers.IPv4{
|
inboundIPv4 := &layers.IPv4{
|
||||||
TTL: 64,
|
TTL: 64,
|
||||||
Version: 4,
|
Version: 4,
|
||||||
SrcIP: dstIP, // Original destination is now source
|
SrcIP: dstIP.AsSlice(), // Original destination is now source
|
||||||
DstIP: srcIP, // Original source is now destination
|
DstIP: srcIP.AsSlice(), // Original source is now destination
|
||||||
Protocol: layers.IPProtocolUDP,
|
Protocol: layers.IPProtocolUDP,
|
||||||
}
|
}
|
||||||
inboundUDP := &layers.UDP{
|
inboundUDP := &layers.UDP{
|
||||||
@@ -636,7 +635,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
for _, cp := range checkPoints {
|
for _, cp := range checkPoints {
|
||||||
time.Sleep(cp.sleep)
|
time.Sleep(cp.sleep)
|
||||||
|
|
||||||
drop = manager.dropFilter(inboundBuf.Bytes())
|
drop = manager.dropFilter(inboundBuf.Bytes(), 0)
|
||||||
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
require.Equal(t, cp.shouldAllow, !drop, cp.description)
|
||||||
|
|
||||||
// If the connection should still be valid, verify it exists
|
// If the connection should still be valid, verify it exists
|
||||||
@@ -685,7 +684,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Create a new outbound connection for invalid tests
|
// Create a new outbound connection for invalid tests
|
||||||
drop = manager.processOutgoingHooks(outboundBuf.Bytes())
|
drop = manager.processOutgoingHooks(outboundBuf.Bytes(), 0)
|
||||||
require.False(t, drop, "Second outbound packet should not be dropped")
|
require.False(t, drop, "Second outbound packet should not be dropped")
|
||||||
|
|
||||||
for _, tc := range invalidCases {
|
for _, tc := range invalidCases {
|
||||||
@@ -707,7 +706,7 @@ func TestStatefulFirewall_UDPTracking(t *testing.T) {
|
|||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
// Verify the invalid packet is dropped
|
// Verify the invalid packet is dropped
|
||||||
drop = manager.dropFilter(testBuf.Bytes())
|
drop = manager.dropFilter(testBuf.Bytes(), 0)
|
||||||
require.True(t, drop, tc.description)
|
require.True(t, drop, tc.description)
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -150,7 +150,7 @@ func isZeros(ip net.IP) bool {
|
|||||||
// NewUDPMuxDefault creates an implementation of UDPMux
|
// NewUDPMuxDefault creates an implementation of UDPMux
|
||||||
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
func NewUDPMuxDefault(params UDPMuxParams) *UDPMuxDefault {
|
||||||
if params.Logger == nil {
|
if params.Logger == nil {
|
||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = getLogger()
|
||||||
}
|
}
|
||||||
|
|
||||||
mux := &UDPMuxDefault{
|
mux := &UDPMuxDefault{
|
||||||
@@ -455,3 +455,9 @@ func newBufferHolder(size int) *bufferHolder {
|
|||||||
buf: make([]byte, size),
|
buf: make([]byte, size),
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func getLogger() logging.LeveledLogger {
|
||||||
|
fac := logging.NewDefaultLoggerFactory()
|
||||||
|
//fac.Writer = log.StandardLogger().Writer()
|
||||||
|
return fac.NewLogger("ice")
|
||||||
|
}
|
||||||
|
|||||||
@@ -49,7 +49,7 @@ type UniversalUDPMuxParams struct {
|
|||||||
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
// NewUniversalUDPMuxDefault creates an implementation of UniversalUDPMux embedding UDPMux
|
||||||
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
func NewUniversalUDPMuxDefault(params UniversalUDPMuxParams) *UniversalUDPMuxDefault {
|
||||||
if params.Logger == nil {
|
if params.Logger == nil {
|
||||||
params.Logger = logging.NewDefaultLoggerFactory().NewLogger("ice")
|
params.Logger = getLogger()
|
||||||
}
|
}
|
||||||
if params.XORMappedAddrCacheTTL == 0 {
|
if params.XORMappedAddrCacheTTL == 0 {
|
||||||
params.XORMappedAddrCacheTTL = time.Second * 25
|
params.XORMappedAddrCacheTTL = time.Second * 25
|
||||||
|
|||||||
@@ -357,7 +357,7 @@ func toWgUserspaceString(wgCfg wgtypes.Config) string {
|
|||||||
|
|
||||||
func getFwmark() int {
|
func getFwmark() int {
|
||||||
if nbnet.AdvancedRouting() {
|
if nbnet.AdvancedRouting() {
|
||||||
return nbnet.NetbirdFwmark
|
return nbnet.ControlPlaneMark
|
||||||
}
|
}
|
||||||
return 0
|
return 0
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -2,6 +2,7 @@ package device
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"golang.zx2c4.com/wireguard/tun"
|
"golang.zx2c4.com/wireguard/tun"
|
||||||
@@ -10,16 +11,16 @@ import (
|
|||||||
// PacketFilter interface for firewall abilities
|
// PacketFilter interface for firewall abilities
|
||||||
type PacketFilter interface {
|
type PacketFilter interface {
|
||||||
// DropOutgoing filter outgoing packets from host to external destinations
|
// DropOutgoing filter outgoing packets from host to external destinations
|
||||||
DropOutgoing(packetData []byte) bool
|
DropOutgoing(packetData []byte, size int) bool
|
||||||
|
|
||||||
// DropIncoming filter incoming packets from external sources to host
|
// DropIncoming filter incoming packets from external sources to host
|
||||||
DropIncoming(packetData []byte) bool
|
DropIncoming(packetData []byte, size int) bool
|
||||||
|
|
||||||
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
// AddUDPPacketHook calls hook when UDP packet from given direction matched
|
||||||
//
|
//
|
||||||
// Hook function returns flag which indicates should be the matched package dropped or not.
|
// Hook function returns flag which indicates should be the matched package dropped or not.
|
||||||
// Hook function receives raw network packet data as argument.
|
// Hook function receives raw network packet data as argument.
|
||||||
AddUDPPacketHook(in bool, ip net.IP, dPort uint16, hook func(packet []byte) bool) string
|
AddUDPPacketHook(in bool, ip netip.Addr, dPort uint16, hook func(packet []byte) bool) string
|
||||||
|
|
||||||
// RemovePacketHook removes hook by ID
|
// RemovePacketHook removes hook by ID
|
||||||
RemovePacketHook(hookID string) error
|
RemovePacketHook(hookID string) error
|
||||||
@@ -57,7 +58,7 @@ func (d *FilteredDevice) Read(bufs [][]byte, sizes []int, offset int) (n int, er
|
|||||||
}
|
}
|
||||||
|
|
||||||
for i := 0; i < n; i++ {
|
for i := 0; i < n; i++ {
|
||||||
if filter.DropOutgoing(bufs[i][offset : offset+sizes[i]]) {
|
if filter.DropOutgoing(bufs[i][offset:offset+sizes[i]], sizes[i]) {
|
||||||
bufs = append(bufs[:i], bufs[i+1:]...)
|
bufs = append(bufs[:i], bufs[i+1:]...)
|
||||||
sizes = append(sizes[:i], sizes[i+1:]...)
|
sizes = append(sizes[:i], sizes[i+1:]...)
|
||||||
n--
|
n--
|
||||||
@@ -81,7 +82,7 @@ func (d *FilteredDevice) Write(bufs [][]byte, offset int) (int, error) {
|
|||||||
filteredBufs := make([][]byte, 0, len(bufs))
|
filteredBufs := make([][]byte, 0, len(bufs))
|
||||||
dropped := 0
|
dropped := 0
|
||||||
for _, buf := range bufs {
|
for _, buf := range bufs {
|
||||||
if !filter.DropIncoming(buf[offset:]) {
|
if !filter.DropIncoming(buf[offset:], len(buf)) {
|
||||||
filteredBufs = append(filteredBufs, buf)
|
filteredBufs = append(filteredBufs, buf)
|
||||||
dropped++
|
dropped++
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -146,7 +146,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
tun.EXPECT().Write(mockBufs, 0).Return(0, nil)
|
||||||
|
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropIncoming(gomock.Any()).Return(true)
|
filter.EXPECT().DropIncoming(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
@@ -201,7 +201,7 @@ func TestDeviceWrapperRead(t *testing.T) {
|
|||||||
return 1, nil
|
return 1, nil
|
||||||
})
|
})
|
||||||
filter := mocks.NewMockPacketFilter(ctrl)
|
filter := mocks.NewMockPacketFilter(ctrl)
|
||||||
filter.EXPECT().DropOutgoing(gomock.Any()).Return(true)
|
filter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).Return(true)
|
||||||
|
|
||||||
wrapped := newDeviceFilter(tun)
|
wrapped := newDeviceFilter(tun)
|
||||||
wrapped.filter = filter
|
wrapped.filter = filter
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ package mocks
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
net "net"
|
net "net"
|
||||||
|
"net/netip"
|
||||||
reflect "reflect"
|
reflect "reflect"
|
||||||
|
|
||||||
gomock "github.com/golang/mock/gomock"
|
gomock "github.com/golang/mock/gomock"
|
||||||
@@ -35,7 +36,7 @@ func (m *MockPacketFilter) EXPECT() *MockPacketFilterMockRecorder {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// AddUDPPacketHook mocks base method.
|
// AddUDPPacketHook mocks base method.
|
||||||
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 net.IP, arg2 uint16, arg3 func([]byte) bool) string {
|
func (m *MockPacketFilter) AddUDPPacketHook(arg0 bool, arg1 netip.Addr, arg2 uint16, arg3 func([]byte) bool) string {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
ret := m.ctrl.Call(m, "AddUDPPacketHook", arg0, arg1, arg2, arg3)
|
||||||
ret0, _ := ret[0].(string)
|
ret0, _ := ret[0].(string)
|
||||||
@@ -49,31 +50,31 @@ func (mr *MockPacketFilterMockRecorder) AddUDPPacketHook(arg0, arg1, arg2, arg3
|
|||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming mocks base method.
|
// DropIncoming mocks base method.
|
||||||
func (m *MockPacketFilter) DropIncoming(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropIncoming(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropIncoming", arg0)
|
ret := m.ctrl.Call(m, "DropIncoming", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropIncoming indicates an expected call of DropIncoming.
|
// DropIncoming indicates an expected call of DropIncoming.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropIncoming(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropIncoming", reflect.TypeOf((*MockPacketFilter)(nil).DropIncoming), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing mocks base method.
|
// DropOutgoing mocks base method.
|
||||||
func (m *MockPacketFilter) DropOutgoing(arg0 []byte) bool {
|
func (m *MockPacketFilter) DropOutgoing(arg0 []byte, arg1 int) bool {
|
||||||
m.ctrl.T.Helper()
|
m.ctrl.T.Helper()
|
||||||
ret := m.ctrl.Call(m, "DropOutgoing", arg0)
|
ret := m.ctrl.Call(m, "DropOutgoing", arg0, arg1)
|
||||||
ret0, _ := ret[0].(bool)
|
ret0, _ := ret[0].(bool)
|
||||||
return ret0
|
return ret0
|
||||||
}
|
}
|
||||||
|
|
||||||
// DropOutgoing indicates an expected call of DropOutgoing.
|
// DropOutgoing indicates an expected call of DropOutgoing.
|
||||||
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}) *gomock.Call {
|
func (mr *MockPacketFilterMockRecorder) DropOutgoing(arg0 interface{}, arg1 any) *gomock.Call {
|
||||||
mr.mock.ctrl.T.Helper()
|
mr.mock.ctrl.T.Helper()
|
||||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0)
|
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DropOutgoing", reflect.TypeOf((*MockPacketFilter)(nil).DropOutgoing), arg0, arg1)
|
||||||
}
|
}
|
||||||
|
|
||||||
// RemovePacketHook mocks base method.
|
// RemovePacketHook mocks base method.
|
||||||
|
|||||||
@@ -28,6 +28,11 @@ type Manager interface {
|
|||||||
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
ApplyFiltering(networkMap *mgmProto.NetworkMap)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type protoMatch struct {
|
||||||
|
ips map[string]int
|
||||||
|
policyID []byte
|
||||||
|
}
|
||||||
|
|
||||||
// DefaultManager uses firewall manager to handle
|
// DefaultManager uses firewall manager to handle
|
||||||
type DefaultManager struct {
|
type DefaultManager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
@@ -240,7 +245,7 @@ func (d *DefaultManager) applyRouteACL(rule *mgmProto.RouteFirewallRule) (id.Rul
|
|||||||
|
|
||||||
dPorts := convertPortInfo(rule.PortInfo)
|
dPorts := convertPortInfo(rule.PortInfo)
|
||||||
|
|
||||||
addedRule, err := d.firewall.AddRouteFiltering(sources, destination, protocol, nil, dPorts, action)
|
addedRule, err := d.firewall.AddRouteFiltering(rule.PolicyID, sources, destination, protocol, nil, dPorts, action)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return "", fmt.Errorf("add route rule: %w", err)
|
return "", fmt.Errorf("add route rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -281,7 +286,7 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action, "")
|
ruleID := d.getPeerRuleID(ip, protocol, int(r.Direction), port, action)
|
||||||
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
if rulesPair, ok := d.peerRulesPairs[ruleID]; ok {
|
||||||
return ruleID, rulesPair, nil
|
return ruleID, rulesPair, nil
|
||||||
}
|
}
|
||||||
@@ -289,11 +294,11 @@ func (d *DefaultManager) protoRuleToFirewallRule(
|
|||||||
var rules []firewall.Rule
|
var rules []firewall.Rule
|
||||||
switch r.Direction {
|
switch r.Direction {
|
||||||
case mgmProto.RuleDirection_IN:
|
case mgmProto.RuleDirection_IN:
|
||||||
rules, err = d.addInRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addInRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
case mgmProto.RuleDirection_OUT:
|
case mgmProto.RuleDirection_OUT:
|
||||||
// TODO: Remove this soon. Outbound rules are obsolete.
|
// TODO: Remove this soon. Outbound rules are obsolete.
|
||||||
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
// We only maintain this for return traffic (inbound dir) which is now handled by the stateful firewall already
|
||||||
rules, err = d.addOutRules(ip, protocol, port, action, ipsetName, "")
|
rules, err = d.addOutRules(r.PolicyID, ip, protocol, port, action, ipsetName)
|
||||||
default:
|
default:
|
||||||
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
return "", nil, fmt.Errorf("invalid direction, skipping firewall rule")
|
||||||
}
|
}
|
||||||
@@ -322,14 +327,14 @@ func portInfoEmpty(portInfo *mgmProto.PortInfo) bool {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addInRules(
|
func (d *DefaultManager) addInRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, nil, port, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, nil, port, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -338,18 +343,18 @@ func (d *DefaultManager) addInRules(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (d *DefaultManager) addOutRules(
|
func (d *DefaultManager) addOutRules(
|
||||||
|
id []byte,
|
||||||
ip net.IP,
|
ip net.IP,
|
||||||
protocol firewall.Protocol,
|
protocol firewall.Protocol,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
ipsetName string,
|
ipsetName string,
|
||||||
comment string,
|
|
||||||
) ([]firewall.Rule, error) {
|
) ([]firewall.Rule, error) {
|
||||||
if shouldSkipInvertedRule(protocol, port) {
|
if shouldSkipInvertedRule(protocol, port) {
|
||||||
return nil, nil
|
return nil, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
rule, err := d.firewall.AddPeerFiltering(ip, protocol, port, nil, action, ipsetName, comment)
|
rule, err := d.firewall.AddPeerFiltering(id, ip, protocol, port, nil, action, ipsetName)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("add firewall rule: %w", err)
|
return nil, fmt.Errorf("add firewall rule: %w", err)
|
||||||
}
|
}
|
||||||
@@ -364,9 +369,8 @@ func (d *DefaultManager) getPeerRuleID(
|
|||||||
direction int,
|
direction int,
|
||||||
port *firewall.Port,
|
port *firewall.Port,
|
||||||
action firewall.Action,
|
action firewall.Action,
|
||||||
comment string,
|
|
||||||
) id.RuleID {
|
) id.RuleID {
|
||||||
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action)) + comment
|
idStr := ip.String() + string(proto) + strconv.Itoa(direction) + strconv.Itoa(int(action))
|
||||||
if port != nil {
|
if port != nil {
|
||||||
idStr += port.String()
|
idStr += port.String()
|
||||||
}
|
}
|
||||||
@@ -389,10 +393,8 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
type protoMatch map[mgmProto.RuleProtocol]map[string]int
|
in := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||||
|
out := map[mgmProto.RuleProtocol]*protoMatch{}
|
||||||
in := protoMatch{}
|
|
||||||
out := protoMatch{}
|
|
||||||
|
|
||||||
// trace which type of protocols was squashed
|
// trace which type of protocols was squashed
|
||||||
squashedRules := []*mgmProto.FirewallRule{}
|
squashedRules := []*mgmProto.FirewallRule{}
|
||||||
@@ -405,14 +407,18 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
// 2. Any of rule contains Port.
|
// 2. Any of rule contains Port.
|
||||||
//
|
//
|
||||||
// We zeroed this to notify squash function that this protocol can't be squashed.
|
// We zeroed this to notify squash function that this protocol can't be squashed.
|
||||||
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols protoMatch) {
|
addRuleToCalculationMap := func(i int, r *mgmProto.FirewallRule, protocols map[mgmProto.RuleProtocol]*protoMatch) {
|
||||||
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
drop := r.Action == mgmProto.RuleAction_DROP || r.Port != ""
|
||||||
if drop {
|
if drop {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = &protoMatch{ips: map[string]int{}}
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
if _, ok := protocols[r.Protocol]; !ok {
|
if _, ok := protocols[r.Protocol]; !ok {
|
||||||
protocols[r.Protocol] = map[string]int{}
|
protocols[r.Protocol] = &protoMatch{
|
||||||
|
ips: map[string]int{},
|
||||||
|
// store the first encountered PolicyID for this protocol
|
||||||
|
policyID: r.PolicyID,
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// special case, when we receive this all network IP address
|
// special case, when we receive this all network IP address
|
||||||
@@ -424,7 +430,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
ipset := protocols[r.Protocol]
|
ipset := protocols[r.Protocol].ips
|
||||||
|
|
||||||
if _, ok := ipset[r.PeerIP]; ok {
|
if _, ok := ipset[r.PeerIP]; ok {
|
||||||
return
|
return
|
||||||
@@ -450,9 +456,10 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
mgmProto.RuleProtocol_UDP,
|
mgmProto.RuleProtocol_UDP,
|
||||||
}
|
}
|
||||||
|
|
||||||
squash := func(matches protoMatch, direction mgmProto.RuleDirection) {
|
squash := func(matches map[mgmProto.RuleProtocol]*protoMatch, direction mgmProto.RuleDirection) {
|
||||||
for _, protocol := range protocolOrders {
|
for _, protocol := range protocolOrders {
|
||||||
if ipset, ok := matches[protocol]; !ok || len(ipset) != totalIPs || len(ipset) < 2 {
|
match, ok := matches[protocol]
|
||||||
|
if !ok || len(match.ips) != totalIPs || len(match.ips) < 2 {
|
||||||
// don't squash if :
|
// don't squash if :
|
||||||
// 1. Rules not cover all peers in the network
|
// 1. Rules not cover all peers in the network
|
||||||
// 2. Rules cover only one peer in the network.
|
// 2. Rules cover only one peer in the network.
|
||||||
@@ -465,6 +472,7 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
Direction: direction,
|
Direction: direction,
|
||||||
Action: mgmProto.RuleAction_ACCEPT,
|
Action: mgmProto.RuleAction_ACCEPT,
|
||||||
Protocol: protocol,
|
Protocol: protocol,
|
||||||
|
PolicyID: match.policyID,
|
||||||
})
|
})
|
||||||
squashedProtocols[protocol] = struct{}{}
|
squashedProtocols[protocol] = struct{}{}
|
||||||
|
|
||||||
@@ -493,9 +501,9 @@ func (d *DefaultManager) squashAcceptRules(
|
|||||||
// if we also have other not squashed rules.
|
// if we also have other not squashed rules.
|
||||||
for i, r := range networkMap.FirewallRules {
|
for i, r := range networkMap.FirewallRules {
|
||||||
if _, ok := squashedProtocols[r.Protocol]; ok {
|
if _, ok := squashedProtocols[r.Protocol]; ok {
|
||||||
if m, ok := in[r.Protocol]; ok && m[r.PeerIP] == i {
|
if m, ok := in[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||||
continue
|
continue
|
||||||
} else if m, ok := out[r.Protocol]; ok && m[r.PeerIP] == i {
|
} else if m, ok := out[r.Protocol]; ok && m.ips[r.PeerIP] == i {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -10,9 +10,12 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/firewall/manager"
|
"github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
"github.com/netbirdio/netbird/client/internal/acl/mocks"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
mgmProto "github.com/netbirdio/netbird/management/proto"
|
mgmProto "github.com/netbirdio/netbird/management/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
func TestDefaultManager(t *testing.T) {
|
func TestDefaultManager(t *testing.T) {
|
||||||
networkMap := &mgmProto.NetworkMap{
|
networkMap := &mgmProto.NetworkMap{
|
||||||
FirewallRules: []*mgmProto.FirewallRule{
|
FirewallRules: []*mgmProto.FirewallRule{
|
||||||
@@ -52,7 +55,7 @@ func TestDefaultManager(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
@@ -346,7 +349,7 @@ func TestDefaultManagerEnableSSHRules(t *testing.T) {
|
|||||||
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
ifaceMock.EXPECT().GetWGDevice().Return(nil).AnyTimes()
|
||||||
|
|
||||||
// we receive one rule from the management so for testing purposes ignore it
|
// we receive one rule from the management so for testing purposes ignore it
|
||||||
fw, err := firewall.NewFirewall(ifaceMock, nil, false)
|
fw, err := firewall.NewFirewall(ifaceMock, nil, flowLogger, false)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Errorf("create firewall: %v", err)
|
t.Errorf("create firewall: %v", err)
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -94,12 +94,17 @@ func (p *PKCEAuthorizationFlow) RequestAuthInfo(ctx context.Context) (AuthFlowIn
|
|||||||
p.codeVerifier = codeVerifier
|
p.codeVerifier = codeVerifier
|
||||||
|
|
||||||
codeChallenge := createCodeChallenge(codeVerifier)
|
codeChallenge := createCodeChallenge(codeVerifier)
|
||||||
authURL := p.oAuthConfig.AuthCodeURL(
|
|
||||||
state,
|
params := []oauth2.AuthCodeOption{
|
||||||
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
oauth2.SetAuthURLParam("code_challenge_method", "S256"),
|
||||||
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
oauth2.SetAuthURLParam("code_challenge", codeChallenge),
|
||||||
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
oauth2.SetAuthURLParam("audience", p.providerConfig.Audience),
|
||||||
)
|
}
|
||||||
|
if !p.providerConfig.DisablePromptLogin {
|
||||||
|
params = append(params, oauth2.SetAuthURLParam("prompt", "login"))
|
||||||
|
}
|
||||||
|
|
||||||
|
authURL := p.oAuthConfig.AuthCodeURL(state, params...)
|
||||||
|
|
||||||
return AuthFlowInfo{
|
return AuthFlowInfo{
|
||||||
VerificationURIComplete: authURL,
|
VerificationURIComplete: authURL,
|
||||||
|
|||||||
49
client/internal/auth/pkce_flow_test.go
Normal file
49
client/internal/auth/pkce_flow_test.go
Normal file
@@ -0,0 +1,49 @@
|
|||||||
|
package auth
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestPromptLogin(t *testing.T) {
|
||||||
|
tt := []struct {
|
||||||
|
name string
|
||||||
|
prompt bool
|
||||||
|
}{
|
||||||
|
{"PromptLogin", true},
|
||||||
|
{"NoPromptLogin", false},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range tt {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
config := internal.PKCEAuthProviderConfig{
|
||||||
|
ClientID: "test-client-id",
|
||||||
|
Audience: "test-audience",
|
||||||
|
TokenEndpoint: "https://test-token-endpoint.com/token",
|
||||||
|
Scope: "openid email profile",
|
||||||
|
AuthorizationEndpoint: "https://test-auth-endpoint.com/authorize",
|
||||||
|
RedirectURLs: []string{"http://127.0.0.1:33992/"},
|
||||||
|
UseIDToken: true,
|
||||||
|
DisablePromptLogin: !tc.prompt,
|
||||||
|
}
|
||||||
|
pkce, err := NewPKCEAuthorizationFlow(config)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to create PKCEAuthorizationFlow: %v", err)
|
||||||
|
}
|
||||||
|
authInfo, err := pkce.RequestAuthInfo(context.Background())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("Failed to request auth info: %v", err)
|
||||||
|
}
|
||||||
|
pattern := "prompt=login"
|
||||||
|
if tc.prompt {
|
||||||
|
require.Contains(t, authInfo.VerificationURIComplete, pattern)
|
||||||
|
} else {
|
||||||
|
require.NotContains(t, authInfo.VerificationURIComplete, pattern)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -349,6 +349,25 @@ func (c *ConnectClient) Engine() *Engine {
|
|||||||
return e
|
return e
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// GetLatestNetworkMap returns the latest network map from the engine.
|
||||||
|
func (c *ConnectClient) GetLatestNetworkMap() (*mgmProto.NetworkMap, error) {
|
||||||
|
engine := c.Engine()
|
||||||
|
if engine == nil {
|
||||||
|
return nil, errors.New("engine is not initialized")
|
||||||
|
}
|
||||||
|
|
||||||
|
networkMap, err := engine.GetLatestNetworkMap()
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("get latest network map: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if networkMap == nil {
|
||||||
|
return nil, errors.New("network map is not available")
|
||||||
|
}
|
||||||
|
|
||||||
|
return networkMap, nil
|
||||||
|
}
|
||||||
|
|
||||||
// Status returns the current client status
|
// Status returns the current client status
|
||||||
func (c *ConnectClient) Status() StatusType {
|
func (c *ConnectClient) Status() StatusType {
|
||||||
if c == nil {
|
if c == nil {
|
||||||
|
|||||||
1022
client/internal/debug/debug.go
Normal file
1022
client/internal/debug/debug.go
Normal file
File diff suppressed because it is too large
Load Diff
@@ -1,9 +1,8 @@
|
|||||||
//go:build linux && !android
|
//go:build linux && !android
|
||||||
|
|
||||||
package server
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"archive/zip"
|
|
||||||
"bytes"
|
"bytes"
|
||||||
"encoding/binary"
|
"encoding/binary"
|
||||||
"fmt"
|
"fmt"
|
||||||
@@ -14,36 +13,31 @@ import (
|
|||||||
"github.com/google/nftables"
|
"github.com/google/nftables"
|
||||||
"github.com/google/nftables/expr"
|
"github.com/google/nftables/expr"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/anonymize"
|
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// addFirewallRules collects and adds firewall rules to the archive
|
// addFirewallRules collects and adds firewall rules to the archive
|
||||||
func (s *Server) addFirewallRules(req *proto.DebugBundleRequest, anonymizer *anonymize.Anonymizer, archive *zip.Writer) error {
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
log.Info("Collecting firewall rules")
|
log.Info("Collecting firewall rules")
|
||||||
// Collect and add iptables rules
|
|
||||||
iptablesRules, err := collectIPTablesRules()
|
iptablesRules, err := collectIPTablesRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to collect iptables rules: %v", err)
|
log.Warnf("Failed to collect iptables rules: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if req.GetAnonymize() {
|
if g.anonymize {
|
||||||
iptablesRules = anonymizer.AnonymizeString(iptablesRules)
|
iptablesRules = g.anonymizer.AnonymizeString(iptablesRules)
|
||||||
}
|
}
|
||||||
if err := addFileToZip(archive, strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
if err := g.addFileToZip(strings.NewReader(iptablesRules), "iptables.txt"); err != nil {
|
||||||
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
log.Warnf("Failed to add iptables rules to bundle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// Collect and add nftables rules
|
|
||||||
nftablesRules, err := collectNFTablesRules()
|
nftablesRules, err := collectNFTablesRules()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to collect nftables rules: %v", err)
|
log.Warnf("Failed to collect nftables rules: %v", err)
|
||||||
} else {
|
} else {
|
||||||
if req.GetAnonymize() {
|
if g.anonymize {
|
||||||
nftablesRules = anonymizer.AnonymizeString(nftablesRules)
|
nftablesRules = g.anonymizer.AnonymizeString(nftablesRules)
|
||||||
}
|
}
|
||||||
if err := addFileToZip(archive, strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
if err := g.addFileToZip(strings.NewReader(nftablesRules), "nftables.txt"); err != nil {
|
||||||
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
log.Warnf("Failed to add nftables rules to bundle: %v", err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -65,16 +59,13 @@ func collectIPTablesRules() (string, error) {
|
|||||||
builder.WriteString("\n")
|
builder.WriteString("\n")
|
||||||
}
|
}
|
||||||
|
|
||||||
// Then get verbose statistics for each table
|
|
||||||
builder.WriteString("=== iptables -v -n -L output ===\n")
|
builder.WriteString("=== iptables -v -n -L output ===\n")
|
||||||
|
|
||||||
// Get list of tables
|
|
||||||
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
tables := []string{"filter", "nat", "mangle", "raw", "security"}
|
||||||
|
|
||||||
for _, table := range tables {
|
for _, table := range tables {
|
||||||
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
builder.WriteString(fmt.Sprintf("*%s\n", table))
|
||||||
|
|
||||||
// Get verbose statistics for the entire table
|
|
||||||
stats, err := getTableStatistics(table)
|
stats, err := getTableStatistics(table)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
log.Warnf("Failed to get statistics for table %s: %v", table, err)
|
||||||
@@ -182,12 +173,10 @@ func formatTables(conn *nftables.Conn, tables []*nftables.Table) string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format chains
|
|
||||||
for _, chain := range chains {
|
for _, chain := range chains {
|
||||||
formatChain(conn, table, chain, &builder)
|
formatChain(conn, table, chain, &builder)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Format sets
|
|
||||||
if sets, err := conn.GetSets(table); err != nil {
|
if sets, err := conn.GetSets(table); err != nil {
|
||||||
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
log.Warnf("Failed to get sets for table %s: %v", table.Name, err)
|
||||||
} else if len(sets) > 0 {
|
} else if len(sets) > 0 {
|
||||||
7
client/internal/debug/debug_mobile.go
Normal file
7
client/internal/debug/debug_mobile.go
Normal file
@@ -0,0 +1,7 @@
|
|||||||
|
//go:build ios || android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
8
client/internal/debug/debug_nonlinux.go
Normal file
8
client/internal/debug/debug_nonlinux.go
Normal file
@@ -0,0 +1,8 @@
|
|||||||
|
//go:build !linux || android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
// collectFirewallRules returns nothing on non-linux systems
|
||||||
|
func (g *BundleGenerator) addFirewallRules() error {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
25
client/internal/debug/debug_nonmobile.go
Normal file
25
client/internal/debug/debug_nonmobile.go
Normal file
@@ -0,0 +1,25 @@
|
|||||||
|
//go:build !ios && !android
|
||||||
|
|
||||||
|
package debug
|
||||||
|
|
||||||
|
import (
|
||||||
|
"fmt"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/routemanager/systemops"
|
||||||
|
)
|
||||||
|
|
||||||
|
func (g *BundleGenerator) addRoutes() error {
|
||||||
|
routes, err := systemops.GetRoutesFromTable()
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("get routes: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: get routes including nexthop
|
||||||
|
routesContent := formatRoutes(routes, g.anonymize, g.anonymizer)
|
||||||
|
routesReader := strings.NewReader(routesContent)
|
||||||
|
if err := g.addFileToZip(routesReader, "routes.txt"); err != nil {
|
||||||
|
return fmt.Errorf("add routes file to zip: %w", err)
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
@@ -1,4 +1,4 @@
|
|||||||
package server
|
package debug
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
@@ -239,7 +239,7 @@ func searchDomains(config HostDNSConfig) []string {
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
|
|
||||||
listOfDomains = append(listOfDomains, dConf.Domain)
|
listOfDomains = append(listOfDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
return listOfDomains
|
return listOfDomains
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -75,12 +75,7 @@ func (c *HandlerChain) AddHandler(pattern string, handler dns.Handler, priority
|
|||||||
}
|
}
|
||||||
|
|
||||||
// First remove any existing handler with same pattern (case-insensitive) and priority
|
// First remove any existing handler with same pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
c.removeEntry(origPattern, priority)
|
||||||
if strings.EqualFold(c.handlers[i].OrigPattern, origPattern) && c.handlers[i].Priority == priority {
|
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
// Check if handler implements SubdomainMatcher interface
|
// Check if handler implements SubdomainMatcher interface
|
||||||
matchSubdomains := false
|
matchSubdomains := false
|
||||||
@@ -133,30 +128,20 @@ func (c *HandlerChain) RemoveHandler(pattern string, priority int) {
|
|||||||
|
|
||||||
pattern = dns.Fqdn(pattern)
|
pattern = dns.Fqdn(pattern)
|
||||||
|
|
||||||
|
c.removeEntry(pattern, priority)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *HandlerChain) removeEntry(pattern string, priority int) {
|
||||||
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
// Find and remove handlers matching both original pattern (case-insensitive) and priority
|
||||||
for i := len(c.handlers) - 1; i >= 0; i-- {
|
for i := len(c.handlers) - 1; i >= 0; i-- {
|
||||||
entry := c.handlers[i]
|
entry := c.handlers[i]
|
||||||
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
if strings.EqualFold(entry.OrigPattern, pattern) && entry.Priority == priority {
|
||||||
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
c.handlers = append(c.handlers[:i], c.handlers[i+1:]...)
|
||||||
return
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
// HasHandlers returns true if there are any handlers remaining for the given pattern
|
|
||||||
func (c *HandlerChain) HasHandlers(pattern string) bool {
|
|
||||||
c.mu.RLock()
|
|
||||||
defer c.mu.RUnlock()
|
|
||||||
|
|
||||||
pattern = strings.ToLower(dns.Fqdn(pattern))
|
|
||||||
for _, entry := range c.handlers {
|
|
||||||
if strings.EqualFold(entry.Pattern, pattern) {
|
|
||||||
return true
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return false
|
|
||||||
}
|
|
||||||
|
|
||||||
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
func (c *HandlerChain) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
||||||
if len(r.Question) == 0 {
|
if len(r.Question) == 0 {
|
||||||
return
|
return
|
||||||
|
|||||||
@@ -443,14 +443,6 @@ func TestHandlerChain_PriorityDeregistration(t *testing.T) {
|
|||||||
for _, handler := range handlers {
|
for _, handler := range handlers {
|
||||||
handler.AssertExpectations(t)
|
handler.AssertExpectations(t)
|
||||||
}
|
}
|
||||||
|
|
||||||
// Verify handler exists check
|
|
||||||
for priority, shouldExist := range tt.expectedCalls {
|
|
||||||
if shouldExist {
|
|
||||||
assert.True(t, chain.HasHandlers(tt.ops[0].pattern),
|
|
||||||
"Handler chain should have handlers for pattern after removing priority %d", priority)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@@ -470,45 +462,69 @@ func TestHandlerChain_MultiPriorityHandling(t *testing.T) {
|
|||||||
r := new(dns.Msg)
|
r := new(dns.Msg)
|
||||||
r.SetQuestion(testQuery, dns.TypeA)
|
r.SetQuestion(testQuery, dns.TypeA)
|
||||||
|
|
||||||
|
// Keep track of mocks for the final assertion in Step 4
|
||||||
|
mocks := []*nbdns.MockSubdomainHandler{routeHandler, matchHandler, defaultHandler}
|
||||||
|
|
||||||
// Add handlers in mixed order
|
// Add handlers in mixed order
|
||||||
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
chain.AddHandler(testDomain, defaultHandler, nbdns.PriorityDefault)
|
||||||
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
chain.AddHandler(testDomain, routeHandler, nbdns.PriorityDNSRoute)
|
||||||
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
chain.AddHandler(testDomain, matchHandler, nbdns.PriorityMatchDomain)
|
||||||
|
|
||||||
// Test 1: Initial state with all three handlers
|
// Test 1: Initial state
|
||||||
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w1 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Highest priority handler (routeHandler) should be called
|
// Highest priority handler (routeHandler) should be called
|
||||||
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
routeHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
matchHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure others are not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w1, r)
|
||||||
routeHandler.AssertExpectations(t)
|
routeHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
routeHandler.ExpectedCalls = nil
|
||||||
|
routeHandler.Calls = nil
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 2: Remove highest priority handler
|
// Test 2: Remove highest priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDNSRoute)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w2 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Now middle priority handler (matchHandler) should be called
|
// Now middle priority handler (matchHandler) should be called
|
||||||
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
matchHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
defaultHandler.On("ServeDNS", mock.Anything, r).Maybe() // Ensure default is not expected yet
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w2, r)
|
||||||
matchHandler.AssertExpectations(t)
|
matchHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
matchHandler.ExpectedCalls = nil
|
||||||
|
matchHandler.Calls = nil
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 3: Remove middle priority handler
|
// Test 3: Remove middle priority handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
chain.RemoveHandler(testDomain, nbdns.PriorityMatchDomain)
|
||||||
assert.True(t, chain.HasHandlers(testDomain))
|
|
||||||
|
|
||||||
w = &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
w3 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
// Now lowest priority handler (defaultHandler) should be called
|
// Now lowest priority handler (defaultHandler) should be called
|
||||||
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
defaultHandler.On("ServeDNS", mock.Anything, r).Return().Once()
|
||||||
|
|
||||||
chain.ServeDNS(w, r)
|
chain.ServeDNS(w3, r)
|
||||||
defaultHandler.AssertExpectations(t)
|
defaultHandler.AssertExpectations(t)
|
||||||
|
|
||||||
|
defaultHandler.ExpectedCalls = nil
|
||||||
|
defaultHandler.Calls = nil
|
||||||
|
|
||||||
// Test 4: Remove last handler
|
// Test 4: Remove last handler
|
||||||
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
chain.RemoveHandler(testDomain, nbdns.PriorityDefault)
|
||||||
|
|
||||||
assert.False(t, chain.HasHandlers(testDomain))
|
w4 := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
chain.ServeDNS(w4, r) // Call ServeDNS on the now empty chain for this domain
|
||||||
|
|
||||||
|
for _, m := range mocks {
|
||||||
|
m.AssertNumberOfCalls(t, "ServeDNS", 0)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
func TestHandlerChain_CaseSensitivity(t *testing.T) {
|
||||||
@@ -830,3 +846,165 @@ func TestHandlerChain_DomainSpecificityOrdering(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestHandlerChain_AddRemoveRoundtrip(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
addPattern string
|
||||||
|
removePattern string
|
||||||
|
queryPattern string
|
||||||
|
shouldBeRemoved bool
|
||||||
|
description string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "exact same pattern",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "case difference",
|
||||||
|
addPattern: "Example.Com.",
|
||||||
|
removePattern: "EXAMPLE.COM.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with mixed case, removing with uppercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed case difference",
|
||||||
|
addPattern: "EXAMPLE.ORG.",
|
||||||
|
removePattern: "example.org.",
|
||||||
|
queryPattern: "example.org.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with uppercase, removing with lowercase",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove wildcard",
|
||||||
|
addPattern: "*.example.com.",
|
||||||
|
removePattern: "*.example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing with identical wildcard patterns",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add wildcard, remove transformed pattern",
|
||||||
|
addPattern: "*.example.net.",
|
||||||
|
removePattern: "example.net.",
|
||||||
|
queryPattern: "sub.example.net.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with wildcard, removing with non-wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "add transformed pattern, remove wildcard",
|
||||||
|
addPattern: "example.io.",
|
||||||
|
removePattern: "*.example.io.",
|
||||||
|
queryPattern: "example.io.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding with non-wildcard pattern, removing with wildcard pattern",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "trailing dot difference",
|
||||||
|
addPattern: "example.dev",
|
||||||
|
removePattern: "example.dev.",
|
||||||
|
queryPattern: "example.dev.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding without trailing dot, removing with trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reversed trailing dot difference",
|
||||||
|
addPattern: "example.app.",
|
||||||
|
removePattern: "example.app",
|
||||||
|
queryPattern: "example.app.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding with trailing dot, removing without trailing dot",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "mixed case and wildcard",
|
||||||
|
addPattern: "*.Example.Site.",
|
||||||
|
removePattern: "*.EXAMPLE.SITE.",
|
||||||
|
queryPattern: "sub.example.site.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding mixed case wildcard, removing uppercase wildcard",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "root zone",
|
||||||
|
addPattern: ".",
|
||||||
|
removePattern: ".",
|
||||||
|
queryPattern: "random.domain.",
|
||||||
|
shouldBeRemoved: true,
|
||||||
|
description: "Adding and removing root zone",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "wrong domain",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "different.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding one domain, trying to remove a different domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "subdomain mismatch",
|
||||||
|
addPattern: "sub.example.com.",
|
||||||
|
removePattern: "example.com.",
|
||||||
|
queryPattern: "sub.example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding subdomain, trying to remove parent domain",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "parent domain mismatch",
|
||||||
|
addPattern: "example.com.",
|
||||||
|
removePattern: "sub.example.com.",
|
||||||
|
queryPattern: "example.com.",
|
||||||
|
shouldBeRemoved: false,
|
||||||
|
description: "Adding parent domain, trying to remove subdomain",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
chain := nbdns.NewHandlerChain()
|
||||||
|
|
||||||
|
handler := &nbdns.MockHandler{}
|
||||||
|
r := new(dns.Msg)
|
||||||
|
r.SetQuestion(tt.queryPattern, dns.TypeA)
|
||||||
|
w := &nbdns.ResponseWriterChain{ResponseWriter: &mockResponseWriter{}}
|
||||||
|
|
||||||
|
// First verify no handler is called before adding any
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS")
|
||||||
|
|
||||||
|
// Add handler
|
||||||
|
chain.AddHandler(tt.addPattern, handler, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Verify handler is called after adding
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
|
||||||
|
// Reset mock for the next test
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
|
||||||
|
// Remove handler
|
||||||
|
chain.RemoveHandler(tt.removePattern, nbdns.PriorityDefault)
|
||||||
|
|
||||||
|
// Set up expectations based on whether removal should succeed
|
||||||
|
if !tt.shouldBeRemoved {
|
||||||
|
handler.On("ServeDNS", mock.Anything, r).Once()
|
||||||
|
}
|
||||||
|
|
||||||
|
// Test if handler is still called after removal attempt
|
||||||
|
chain.ServeDNS(w, r)
|
||||||
|
|
||||||
|
if tt.shouldBeRemoved {
|
||||||
|
handler.AssertNotCalled(t, "ServeDNS",
|
||||||
|
"Handler should not be called after successful removal with pattern %q",
|
||||||
|
tt.removePattern)
|
||||||
|
} else {
|
||||||
|
handler.AssertExpectations(t)
|
||||||
|
handler.ExpectedCalls = nil
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
@@ -5,6 +5,8 @@ import (
|
|||||||
"net/netip"
|
"net/netip"
|
||||||
"strings"
|
"strings"
|
||||||
|
|
||||||
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
@@ -12,8 +14,8 @@ import (
|
|||||||
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
var ErrRouteAllWithoutNameserverGroup = fmt.Errorf("unable to configure DNS for this peer using file manager without a nameserver group with all domains configured")
|
||||||
|
|
||||||
const (
|
const (
|
||||||
ipv4ReverseZone = ".in-addr.arpa"
|
ipv4ReverseZone = ".in-addr.arpa."
|
||||||
ipv6ReverseZone = ".ip6.arpa"
|
ipv6ReverseZone = ".ip6.arpa."
|
||||||
)
|
)
|
||||||
|
|
||||||
type hostManager interface {
|
type hostManager interface {
|
||||||
@@ -103,7 +105,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
|
|
||||||
for _, domain := range nsConfig.Domains {
|
for _, domain := range nsConfig.Domains {
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(domain, "."),
|
Domain: strings.ToLower(dns.Fqdn(domain)),
|
||||||
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
MatchOnly: !nsConfig.SearchDomainsEnabled,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
@@ -112,7 +114,7 @@ func dnsConfigToHostDNSConfig(dnsConfig nbdns.Config, ip string, port int) HostD
|
|||||||
for _, customZone := range dnsConfig.CustomZones {
|
for _, customZone := range dnsConfig.CustomZones {
|
||||||
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
matchOnly := strings.HasSuffix(customZone.Domain, ipv4ReverseZone) || strings.HasSuffix(customZone.Domain, ipv6ReverseZone)
|
||||||
config.Domains = append(config.Domains, DomainConfig{
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
Domain: strings.TrimSuffix(customZone.Domain, "."),
|
Domain: strings.ToLower(dns.Fqdn(customZone.Domain)),
|
||||||
MatchOnly: matchOnly,
|
MatchOnly: matchOnly,
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -79,10 +79,10 @@ func (s *systemConfigurator) applyDNSConfig(config HostDNSConfig, stateManager *
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, dConf.Domain)
|
matchDomains = append(matchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, strings.TrimSuffix(""+dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
matchKey := getKeyWithInput(netbirdDNSStateKeyFormat, matchSuffix)
|
||||||
|
|||||||
@@ -17,15 +17,18 @@ import (
|
|||||||
|
|
||||||
var (
|
var (
|
||||||
userenv = syscall.NewLazyDLL("userenv.dll")
|
userenv = syscall.NewLazyDLL("userenv.dll")
|
||||||
|
dnsapi = syscall.NewLazyDLL("dnsapi.dll")
|
||||||
|
|
||||||
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
// https://learn.microsoft.com/en-us/windows/win32/api/userenv/nf-userenv-refreshpolicyex
|
||||||
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
refreshPolicyExFn = userenv.NewProc("RefreshPolicyEx")
|
||||||
|
|
||||||
|
dnsFlushResolverCacheFn = dnsapi.NewProc("DnsFlushResolverCache")
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
dnsPolicyConfigMatchPath = `SYSTEM\CurrentControlSet\Services\Dnscache\Parameters\DnsPolicyConfig\NetBird-Match`
|
||||||
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient`
|
gpoDnsPolicyRoot = `SOFTWARE\Policies\Microsoft\Windows NT\DNSClient\DnsPolicyConfig`
|
||||||
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\DnsPolicyConfig\NetBird-Match`
|
gpoDnsPolicyConfigMatchPath = gpoDnsPolicyRoot + `\NetBird-Match`
|
||||||
|
|
||||||
dnsPolicyConfigVersionKey = "Version"
|
dnsPolicyConfigVersionKey = "Version"
|
||||||
dnsPolicyConfigVersionValue = 2
|
dnsPolicyConfigVersionValue = 2
|
||||||
@@ -97,9 +100,9 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if !dConf.MatchOnly {
|
if !dConf.MatchOnly {
|
||||||
searchDomains = append(searchDomains, dConf.Domain)
|
searchDomains = append(searchDomains, strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
matchDomains = append(matchDomains, "."+dConf.Domain)
|
matchDomains = append(matchDomains, "."+strings.TrimSuffix(dConf.Domain, "."))
|
||||||
}
|
}
|
||||||
|
|
||||||
if len(matchDomains) != 0 {
|
if len(matchDomains) != 0 {
|
||||||
@@ -116,6 +119,10 @@ func (r *registryConfigurator) applyDNSConfig(config HostDNSConfig, stateManager
|
|||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -136,10 +143,6 @@ func (r *registryConfigurator) addDNSMatchPolicy(domains []string, ip string) er
|
|||||||
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
return fmt.Errorf("configure GPO DNS policy: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.configureDNSPolicy(dnsPolicyConfigMatchPath, domains, ip); err != nil {
|
|
||||||
return fmt.Errorf("configure local DNS policy: %w", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
if err := refreshGroupPolicy(); err != nil {
|
if err := refreshGroupPolicy(); err != nil {
|
||||||
log.Warnf("failed to refresh group policy: %v", err)
|
log.Warnf("failed to refresh group policy: %v", err)
|
||||||
}
|
}
|
||||||
@@ -188,6 +191,26 @@ func (r *registryConfigurator) string() string {
|
|||||||
return "registry"
|
return "registry"
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (r *registryConfigurator) flushDNSCache() error {
|
||||||
|
// dnsFlushResolverCacheFn.Call() may panic if the func is not found
|
||||||
|
defer func() {
|
||||||
|
if rec := recover(); rec != nil {
|
||||||
|
log.Errorf("Recovered from panic in flushDNSCache: %v", rec)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
ret, _, err := dnsFlushResolverCacheFn.Call()
|
||||||
|
if ret == 0 {
|
||||||
|
if err != nil && !errors.Is(err, syscall.Errno(0)) {
|
||||||
|
return fmt.Errorf("DnsFlushResolverCache failed: %w", err)
|
||||||
|
}
|
||||||
|
return fmt.Errorf("DnsFlushResolverCache failed")
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Info("flushed DNS cache")
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
func (r *registryConfigurator) updateSearchDomains(domains []string) error {
|
||||||
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
if err := r.setInterfaceRegistryKeyStringValue(interfaceConfigSearchListKey, strings.Join(domains, ",")); err != nil {
|
||||||
return fmt.Errorf("update search domains: %w", err)
|
return fmt.Errorf("update search domains: %w", err)
|
||||||
@@ -240,6 +263,10 @@ func (r *registryConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("remove interface registry key: %w", err)
|
return fmt.Errorf("remove interface registry key: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := r.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -71,6 +71,12 @@ func (d *localResolver) lookupRecords(r *dns.Msg) []dns.RR {
|
|||||||
|
|
||||||
value, found := d.records.Load(key)
|
value, found := d.records.Load(key)
|
||||||
if !found {
|
if !found {
|
||||||
|
// alternatively check if we have a cname
|
||||||
|
if question.Qtype != dns.TypeCNAME {
|
||||||
|
r.Question[0].Qtype = dns.TypeCNAME
|
||||||
|
return d.lookupRecords(r)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -6,6 +6,7 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
|
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// MockServer is the mock instance of a dns server
|
// MockServer is the mock instance of a dns server
|
||||||
@@ -13,17 +14,17 @@ type MockServer struct {
|
|||||||
InitializeFunc func() error
|
InitializeFunc func() error
|
||||||
StopFunc func()
|
StopFunc func()
|
||||||
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
UpdateDNSServerFunc func(serial uint64, update nbdns.Config) error
|
||||||
RegisterHandlerFunc func([]string, dns.Handler, int)
|
RegisterHandlerFunc func(domain.List, dns.Handler, int)
|
||||||
DeregisterHandlerFunc func([]string, int)
|
DeregisterHandlerFunc func(domain.List, int)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
func (m *MockServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
if m.RegisterHandlerFunc != nil {
|
if m.RegisterHandlerFunc != nil {
|
||||||
m.RegisterHandlerFunc(domains, handler, priority)
|
m.RegisterHandlerFunc(domains, handler, priority)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *MockServer) DeregisterHandler(domains []string, priority int) {
|
func (m *MockServer) DeregisterHandler(domains domain.List, priority int) {
|
||||||
if m.DeregisterHandlerFunc != nil {
|
if m.DeregisterHandlerFunc != nil {
|
||||||
m.DeregisterHandlerFunc(domains, priority)
|
m.DeregisterHandlerFunc(domains, priority)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -13,7 +13,6 @@ import (
|
|||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/hashicorp/go-version"
|
"github.com/hashicorp/go-version"
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
@@ -126,10 +125,10 @@ func (n *networkManagerDbusConfigurator) applyDNSConfig(config HostDNSConfig, st
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if dConf.MatchOnly {
|
if dConf.MatchOnly {
|
||||||
matchDomains = append(matchDomains, "~."+dns.Fqdn(dConf.Domain))
|
matchDomains = append(matchDomains, "~."+dConf.Domain)
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
searchDomains = append(searchDomains, dns.Fqdn(dConf.Domain))
|
searchDomains = append(searchDomains, dConf.Domain)
|
||||||
}
|
}
|
||||||
|
|
||||||
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
newDomainList := append(searchDomains, matchDomains...) //nolint:gocritic
|
||||||
|
|||||||
@@ -6,11 +6,13 @@ import (
|
|||||||
"fmt"
|
"fmt"
|
||||||
"net/netip"
|
"net/netip"
|
||||||
"runtime"
|
"runtime"
|
||||||
|
"strings"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
"github.com/mitchellh/hashstructure/v2"
|
"github.com/mitchellh/hashstructure/v2"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
"golang.org/x/exp/maps"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface/netstack"
|
"github.com/netbirdio/netbird/client/iface/netstack"
|
||||||
"github.com/netbirdio/netbird/client/internal/listener"
|
"github.com/netbirdio/netbird/client/internal/listener"
|
||||||
@@ -18,6 +20,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
cProto "github.com/netbirdio/netbird/client/proto"
|
cProto "github.com/netbirdio/netbird/client/proto"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
// ReadyListener is a notification mechanism what indicate the server is ready to handle host dns address changes
|
||||||
@@ -32,8 +35,8 @@ type IosDnsManager interface {
|
|||||||
|
|
||||||
// Server is a dns server interface
|
// Server is a dns server interface
|
||||||
type Server interface {
|
type Server interface {
|
||||||
RegisterHandler(domains []string, handler dns.Handler, priority int)
|
RegisterHandler(domains domain.List, handler dns.Handler, priority int)
|
||||||
DeregisterHandler(domains []string, priority int)
|
DeregisterHandler(domains domain.List, priority int)
|
||||||
Initialize() error
|
Initialize() error
|
||||||
Stop()
|
Stop()
|
||||||
DnsIP() string
|
DnsIP() string
|
||||||
@@ -65,6 +68,7 @@ type DefaultServer struct {
|
|||||||
previousConfigHash uint64
|
previousConfigHash uint64
|
||||||
currentConfig HostDNSConfig
|
currentConfig HostDNSConfig
|
||||||
handlerChain *HandlerChain
|
handlerChain *HandlerChain
|
||||||
|
extraDomains map[domain.Domain]int
|
||||||
|
|
||||||
// permanent related properties
|
// permanent related properties
|
||||||
permanent bool
|
permanent bool
|
||||||
@@ -164,13 +168,15 @@ func newDefaultServer(
|
|||||||
stateManager *statemanager.Manager,
|
stateManager *statemanager.Manager,
|
||||||
disableSys bool,
|
disableSys bool,
|
||||||
) *DefaultServer {
|
) *DefaultServer {
|
||||||
|
handlerChain := NewHandlerChain()
|
||||||
ctx, stop := context.WithCancel(ctx)
|
ctx, stop := context.WithCancel(ctx)
|
||||||
defaultServer := &DefaultServer{
|
defaultServer := &DefaultServer{
|
||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
ctxCancel: stop,
|
ctxCancel: stop,
|
||||||
disableSys: disableSys,
|
disableSys: disableSys,
|
||||||
service: dnsService,
|
service: dnsService,
|
||||||
handlerChain: NewHandlerChain(),
|
handlerChain: handlerChain,
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
dnsMuxMap: make(registeredHandlerMap),
|
dnsMuxMap: make(registeredHandlerMap),
|
||||||
localResolver: &localResolver{
|
localResolver: &localResolver{
|
||||||
registeredMap: make(registrationMap),
|
registeredMap: make(registrationMap),
|
||||||
@@ -181,14 +187,26 @@ func newDefaultServer(
|
|||||||
hostsDNSHolder: newHostsDNSHolder(),
|
hostsDNSHolder: newHostsDNSHolder(),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// register with root zone, handler chain takes care of the routing
|
||||||
|
dnsService.RegisterMux(".", handlerChain)
|
||||||
|
|
||||||
return defaultServer
|
return defaultServer
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) RegisterHandler(domains []string, handler dns.Handler, priority int) {
|
// RegisterHandler registers a handler for the given domains with the given priority.
|
||||||
|
// Any previously registered handler for the same domain and priority will be replaced.
|
||||||
|
func (s *DefaultServer) RegisterHandler(domains domain.List, handler dns.Handler, priority int) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.registerHandler(domains, handler, priority)
|
s.registerHandler(domains.ToPunycodeList(), handler, priority)
|
||||||
|
|
||||||
|
// TODO: This will take over zones for non-wildcard domains, for which we might not have a handler in the chain
|
||||||
|
for _, domain := range domains {
|
||||||
|
// convert to zone with simple ref counter
|
||||||
|
s.extraDomains[toZone(domain)]++
|
||||||
|
}
|
||||||
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, priority int) {
|
||||||
@@ -200,15 +218,23 @@ func (s *DefaultServer) registerHandler(domains []string, handler dns.Handler, p
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
s.handlerChain.AddHandler(domain, handler, priority)
|
s.handlerChain.AddHandler(domain, handler, priority)
|
||||||
s.service.RegisterMux(nbdns.NormalizeZone(domain), s.handlerChain)
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) DeregisterHandler(domains []string, priority int) {
|
// DeregisterHandler deregisters the handler for the given domains with the given priority.
|
||||||
|
func (s *DefaultServer) DeregisterHandler(domains domain.List, priority int) {
|
||||||
s.mux.Lock()
|
s.mux.Lock()
|
||||||
defer s.mux.Unlock()
|
defer s.mux.Unlock()
|
||||||
|
|
||||||
s.deregisterHandler(domains, priority)
|
s.deregisterHandler(domains.ToPunycodeList(), priority)
|
||||||
|
for _, domain := range domains {
|
||||||
|
zone := toZone(domain)
|
||||||
|
s.extraDomains[zone]--
|
||||||
|
if s.extraDomains[zone] <= 0 {
|
||||||
|
delete(s.extraDomains, zone)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
s.applyHostConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
||||||
@@ -221,11 +247,6 @@ func (s *DefaultServer) deregisterHandler(domains []string, priority int) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.handlerChain.RemoveHandler(domain, priority)
|
s.handlerChain.RemoveHandler(domain, priority)
|
||||||
|
|
||||||
// Only deregister from service if no handlers remain
|
|
||||||
if !s.handlerChain.HasHandlers(domain) {
|
|
||||||
s.service.DeregisterMux(nbdns.NormalizeZone(domain))
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -286,6 +307,8 @@ func (s *DefaultServer) Stop() {
|
|||||||
}
|
}
|
||||||
|
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
|
|
||||||
|
maps.Clear(s.extraDomains)
|
||||||
}
|
}
|
||||||
|
|
||||||
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
// OnUpdatedHostDNSServer update the DNS servers addresses for root zones
|
||||||
@@ -390,7 +413,9 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
// is the service should be Disabled, we stop the listener or fake resolver
|
// is the service should be Disabled, we stop the listener or fake resolver
|
||||||
// and proceed with a regular update to clean up the handlers and records
|
// and proceed with a regular update to clean up the handlers and records
|
||||||
if update.ServiceEnable {
|
if update.ServiceEnable {
|
||||||
_ = s.service.Listen()
|
if err := s.service.Listen(); err != nil {
|
||||||
|
log.Errorf("failed to start DNS service: %v", err)
|
||||||
|
}
|
||||||
} else if !s.permanent {
|
} else if !s.permanent {
|
||||||
s.service.Stop()
|
s.service.Stop()
|
||||||
}
|
}
|
||||||
@@ -413,17 +438,13 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
|
|
||||||
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
s.currentConfig = dnsConfigToHostDNSConfig(update, s.service.RuntimeIP(), s.service.RuntimePort())
|
||||||
|
|
||||||
hostUpdate := s.currentConfig
|
|
||||||
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
if s.service.RuntimePort() != defaultPort && !s.hostManager.supportCustomPort() {
|
||||||
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
log.Warnf("the DNS manager of this peer doesn't support custom port. Disabling primary DNS setup. " +
|
||||||
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
"Learn more at: https://docs.netbird.io/how-to/manage-dns-in-your-network#local-resolver")
|
||||||
hostUpdate.RouteAll = false
|
s.currentConfig.RouteAll = false
|
||||||
}
|
}
|
||||||
|
|
||||||
if err = s.hostManager.applyDNSConfig(hostUpdate, s.stateManager); err != nil {
|
s.applyHostConfig()
|
||||||
log.Error(err)
|
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
// persist dns state right away
|
// persist dns state right away
|
||||||
@@ -441,6 +462,43 @@ func (s *DefaultServer) applyConfiguration(update nbdns.Config) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (s *DefaultServer) applyHostConfig() {
|
||||||
|
if s.hostManager == nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
// prevent reapplying config if we're shutting down
|
||||||
|
if s.ctx.Err() != nil {
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
config := s.currentConfig
|
||||||
|
|
||||||
|
existingDomains := make(map[string]struct{})
|
||||||
|
for _, d := range config.Domains {
|
||||||
|
existingDomains[d.Domain] = struct{}{}
|
||||||
|
}
|
||||||
|
|
||||||
|
// add extra domains only if they're not already in the config
|
||||||
|
for domain := range s.extraDomains {
|
||||||
|
domainStr := domain.PunycodeString()
|
||||||
|
|
||||||
|
if _, exists := existingDomains[domainStr]; !exists {
|
||||||
|
config.Domains = append(config.Domains, DomainConfig{
|
||||||
|
Domain: domainStr,
|
||||||
|
MatchOnly: true,
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Debugf("extra match domains: %v", s.extraDomains)
|
||||||
|
|
||||||
|
if err := s.hostManager.applyDNSConfig(config, s.stateManager); err != nil {
|
||||||
|
log.Errorf("failed to apply DNS host manager update: %v", err)
|
||||||
|
s.handleErrNoGroupaAll(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
func (s *DefaultServer) handleErrNoGroupaAll(err error) {
|
||||||
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
if !errors.Is(ErrRouteAllWithoutNameserverGroup, err) {
|
||||||
return
|
return
|
||||||
@@ -690,10 +748,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
s.applyHostConfig()
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
l.Errorf("Failed to apply nameserver deactivation on the host: %v", err)
|
|
||||||
}
|
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
if err := s.stateManager.PersistState(s.ctx); err != nil {
|
||||||
@@ -728,12 +783,7 @@ func (s *DefaultServer) upstreamCallbacks(
|
|||||||
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
s.registerHandler([]string{nbdns.RootZone}, handler, priority)
|
||||||
}
|
}
|
||||||
|
|
||||||
if s.hostManager != nil {
|
s.applyHostConfig()
|
||||||
if err := s.hostManager.applyDNSConfig(s.currentConfig, s.stateManager); err != nil {
|
|
||||||
s.handleErrNoGroupaAll(err)
|
|
||||||
l.WithError(err).Error("reactivate temporary disabled nameserver group, DNS update apply")
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
s.updateNSState(nsGroup, nil, true)
|
s.updateNSState(nsGroup, nil, true)
|
||||||
}
|
}
|
||||||
@@ -836,3 +886,13 @@ func groupNSGroupsByDomain(nsGroups []*nbdns.NameServerGroup) []nsGroupsByDomain
|
|||||||
|
|
||||||
return result
|
return result
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func toZone(d domain.Domain) domain.Domain {
|
||||||
|
return domain.Domain(
|
||||||
|
nbdns.NormalizeZone(
|
||||||
|
dns.Fqdn(
|
||||||
|
strings.ToLower(d.PunycodeString()),
|
||||||
|
),
|
||||||
|
),
|
||||||
|
)
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,19 +23,23 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/iface/device"
|
"github.com/netbirdio/netbird/client/iface/device"
|
||||||
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
pfmock "github.com/netbirdio/netbird/client/iface/mocks"
|
||||||
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
"github.com/netbirdio/netbird/client/iface/wgaddr"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/statemanager"
|
"github.com/netbirdio/netbird/client/internal/statemanager"
|
||||||
"github.com/netbirdio/netbird/client/internal/stdnet"
|
"github.com/netbirdio/netbird/client/internal/stdnet"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
"github.com/netbirdio/netbird/formatter"
|
"github.com/netbirdio/netbird/formatter"
|
||||||
|
"github.com/netbirdio/netbird/management/domain"
|
||||||
)
|
)
|
||||||
|
|
||||||
|
var flowLogger = netflow.NewManager(nil, []byte{}, nil).GetLogger()
|
||||||
|
|
||||||
type mocWGIface struct {
|
type mocWGIface struct {
|
||||||
filter device.PacketFilter
|
filter device.PacketFilter
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Name() string {
|
func (w *mocWGIface) Name() string {
|
||||||
panic("implement me")
|
return "utun2301"
|
||||||
}
|
}
|
||||||
|
|
||||||
func (w *mocWGIface) Address() wgaddr.Address {
|
func (w *mocWGIface) Address() wgaddr.Address {
|
||||||
@@ -456,7 +460,7 @@ func TestDNSFakeResolverHandleUpdates(t *testing.T) {
|
|||||||
}
|
}
|
||||||
|
|
||||||
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
packetfilter := pfmock.NewMockPacketFilter(ctrl)
|
||||||
packetfilter.EXPECT().DropOutgoing(gomock.Any()).AnyTimes()
|
packetfilter.EXPECT().DropOutgoing(gomock.Any(), gomock.Any()).AnyTimes()
|
||||||
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
packetfilter.EXPECT().AddUDPPacketHook(gomock.Any(), gomock.Any(), gomock.Any(), gomock.Any())
|
||||||
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
packetfilter.EXPECT().RemovePacketHook(gomock.Any())
|
||||||
packetfilter.EXPECT().SetNetwork(ipNet)
|
packetfilter.EXPECT().SetNetwork(ipNet)
|
||||||
@@ -917,7 +921,7 @@ func createWgInterfaceWithBind(t *testing.T) (*iface.WGIface, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
pf, err := uspfilter.Create(wgIface, false)
|
pf, err := uspfilter.Create(wgIface, false, flowLogger)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
t.Fatalf("failed to create uspfilter: %v", err)
|
t.Fatalf("failed to create uspfilter: %v", err)
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -1445,3 +1449,497 @@ func TestDefaultServer_UpdateMux(t *testing.T) {
|
|||||||
})
|
})
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func TestExtraDomains(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
initialConfig nbdns.Config
|
||||||
|
registerDomains []domain.List
|
||||||
|
deregisterDomains []domain.List
|
||||||
|
finalConfig nbdns.Config
|
||||||
|
expectedDomains []string
|
||||||
|
expectedMatchOnly []string
|
||||||
|
applyHostConfigCall int
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Register domains before config update",
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
},
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domains after config update",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra1.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register overlapping domains",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "overlap.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "overlap.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"overlap.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register and deregister domains",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra2.example.com"},
|
||||||
|
{"extra3.example.com", "extra4.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"extra1.example.com", "extra3.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra2.example.com.",
|
||||||
|
"extra4.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra2.example.com.",
|
||||||
|
"extra4.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domains with ref counter",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "duplicate.example.com"},
|
||||||
|
{"other.example.com", "duplicate.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"duplicate.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
"other.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"other.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 4,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Config update with new domains after registration",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "duplicate.example.com"},
|
||||||
|
},
|
||||||
|
finalConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "newconfig.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"config.example.com.",
|
||||||
|
"newconfig.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"duplicate.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Deregister domain that is part of customZones",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "protected.example.com"},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "protected.example.com"},
|
||||||
|
},
|
||||||
|
deregisterDomains: []domain.List{
|
||||||
|
{"protected.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
"config.example.com.",
|
||||||
|
"protected.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 3,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Register domain that is part of nameserver group",
|
||||||
|
initialConfig: nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
NameServerGroups: []*nbdns.NameServerGroup{
|
||||||
|
{
|
||||||
|
Domains: []string{"ns.example.com", "overlap.ns.example.com"},
|
||||||
|
NameServers: []nbdns.NameServer{
|
||||||
|
{
|
||||||
|
IP: netip.MustParseAddr("8.8.8.8"),
|
||||||
|
NSType: nbdns.UDPNameServerType,
|
||||||
|
Port: 53,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
registerDomains: []domain.List{
|
||||||
|
{"extra.example.com", "overlap.ns.example.com"},
|
||||||
|
},
|
||||||
|
expectedDomains: []string{
|
||||||
|
"ns.example.com.",
|
||||||
|
"overlap.ns.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
expectedMatchOnly: []string{
|
||||||
|
"ns.example.com.",
|
||||||
|
"overlap.ns.example.com.",
|
||||||
|
"extra.example.com.",
|
||||||
|
},
|
||||||
|
applyHostConfigCall: 2,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
var capturedConfigs []HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfigs = append(capturedConfigs, config)
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
wgInterface: &mocWGIface{},
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply initial configuration
|
||||||
|
if tt.initialConfig.ServiceEnable {
|
||||||
|
err := server.applyConfiguration(tt.initialConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register domains
|
||||||
|
for _, domains := range tt.registerDomains {
|
||||||
|
server.RegisterHandler(domains, &MockHandler{}, PriorityDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Deregister domains if specified
|
||||||
|
for _, domains := range tt.deregisterDomains {
|
||||||
|
server.DeregisterHandler(domains, PriorityDefault)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Apply final configuration if specified
|
||||||
|
if tt.finalConfig.ServiceEnable {
|
||||||
|
err := server.applyConfiguration(tt.finalConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify number of calls
|
||||||
|
assert.Equal(t, tt.applyHostConfigCall, len(capturedConfigs),
|
||||||
|
"Expected %d calls to applyDNSConfig, got %d", tt.applyHostConfigCall, len(capturedConfigs))
|
||||||
|
|
||||||
|
// Get the last applied config
|
||||||
|
lastConfig := capturedConfigs[len(capturedConfigs)-1]
|
||||||
|
|
||||||
|
// Check all expected domains are present
|
||||||
|
domainMap := make(map[string]bool)
|
||||||
|
matchOnlyMap := make(map[string]bool)
|
||||||
|
|
||||||
|
for _, d := range lastConfig.Domains {
|
||||||
|
domainMap[d.Domain] = true
|
||||||
|
if d.MatchOnly {
|
||||||
|
matchOnlyMap[d.Domain] = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify expected domains
|
||||||
|
for _, d := range tt.expectedDomains {
|
||||||
|
assert.True(t, domainMap[d], "Expected domain %s not found in final config", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify match-only domains
|
||||||
|
for _, d := range tt.expectedMatchOnly {
|
||||||
|
assert.True(t, matchOnlyMap[d], "Expected match-only domain %s not found in final config", d)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Verify no unexpected domains
|
||||||
|
assert.Equal(t, len(tt.expectedDomains), len(domainMap), "Unexpected number of domains in final config")
|
||||||
|
assert.Equal(t, len(tt.expectedMatchOnly), len(matchOnlyMap), "Unexpected number of match-only domains in final config")
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestExtraDomainsRefCounting(t *testing.T) {
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
// Register domains from different handlers with same domain
|
||||||
|
server.RegisterHandler(domain.List{"*.shared.example.com"}, &MockHandler{}, PriorityDNSRoute)
|
||||||
|
server.RegisterHandler(domain.List{"shared.example.com."}, &MockHandler{}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
// Verify refcount is 2
|
||||||
|
zoneKey := toZone("shared.example.com")
|
||||||
|
assert.Equal(t, 2, server.extraDomains[zoneKey], "Refcount should be 2 after registering same domain twice")
|
||||||
|
|
||||||
|
// Deregister one handler
|
||||||
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
// Verify refcount is 1
|
||||||
|
assert.Equal(t, 1, server.extraDomains[zoneKey], "Refcount should be 1 after deregistering one handler")
|
||||||
|
|
||||||
|
// Deregister the other handler
|
||||||
|
server.DeregisterHandler(domain.List{"shared.example.com"}, PriorityDNSRoute)
|
||||||
|
|
||||||
|
// Verify domain is removed
|
||||||
|
_, exists := server.extraDomains[zoneKey]
|
||||||
|
assert.False(t, exists, "Domain should be removed after deregistering all handlers")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestUpdateConfigWithExistingExtraDomains(t *testing.T) {
|
||||||
|
var capturedConfig HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfig = config
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterHandler(domain.List{"extra.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
|
|
||||||
|
initialConfig := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := server.applyConfiguration(initialConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
}
|
||||||
|
assert.Contains(t, domains, "config.example.com.")
|
||||||
|
assert.Contains(t, domains, "extra.example.com.")
|
||||||
|
|
||||||
|
// Now apply a new configuration with overlapping domain
|
||||||
|
updatedConfig := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
{Domain: "extra.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err = server.applyConfiguration(updatedConfig)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
// Verify both domains are in config, but no duplicates
|
||||||
|
domains = []string{}
|
||||||
|
matchOnlyCount := 0
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
if d.MatchOnly {
|
||||||
|
matchOnlyCount++
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
assert.Contains(t, domains, "config.example.com.")
|
||||||
|
assert.Contains(t, domains, "extra.example.com.")
|
||||||
|
assert.Equal(t, 2, len(domains), "Should have exactly 2 domains with no duplicates")
|
||||||
|
|
||||||
|
// Extra domain should no longer be marked as match-only when in config
|
||||||
|
matchOnlyDomain := ""
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
if d.Domain == "extra.example.com." && d.MatchOnly {
|
||||||
|
matchOnlyDomain = d.Domain
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
assert.Empty(t, matchOnlyDomain, "Domain should not be match-only when included in config")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestDomainCaseHandling(t *testing.T) {
|
||||||
|
var capturedConfig HostDNSConfig
|
||||||
|
mockHostConfig := &mockHostConfigurator{
|
||||||
|
applyDNSConfigFunc: func(config HostDNSConfig, _ *statemanager.Manager) error {
|
||||||
|
capturedConfig = config
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
restoreHostDNSFunc: func() error {
|
||||||
|
return nil
|
||||||
|
},
|
||||||
|
supportCustomPortFunc: func() bool {
|
||||||
|
return true
|
||||||
|
},
|
||||||
|
stringFunc: func() string {
|
||||||
|
return "mock"
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
mockSvc := &mockService{}
|
||||||
|
server := &DefaultServer{
|
||||||
|
ctx: context.Background(),
|
||||||
|
handlerChain: NewHandlerChain(),
|
||||||
|
hostManager: mockHostConfig,
|
||||||
|
localResolver: &localResolver{},
|
||||||
|
service: mockSvc,
|
||||||
|
statusRecorder: peer.NewRecorder("test"),
|
||||||
|
extraDomains: make(map[domain.Domain]int),
|
||||||
|
}
|
||||||
|
|
||||||
|
server.RegisterHandler(domain.List{"MIXED.example.com"}, &MockHandler{}, PriorityDefault)
|
||||||
|
server.RegisterHandler(domain.List{"mixed.EXAMPLE.com"}, &MockHandler{}, PriorityMatchDomain)
|
||||||
|
|
||||||
|
assert.Equal(t, 1, len(server.extraDomains), "Case differences should be normalized")
|
||||||
|
|
||||||
|
config := nbdns.Config{
|
||||||
|
ServiceEnable: true,
|
||||||
|
CustomZones: []nbdns.CustomZone{
|
||||||
|
{Domain: "config.example.com"},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
err := server.applyConfiguration(config)
|
||||||
|
assert.NoError(t, err)
|
||||||
|
|
||||||
|
var domains []string
|
||||||
|
for _, d := range capturedConfig.Domains {
|
||||||
|
domains = append(domains, d.Domain)
|
||||||
|
}
|
||||||
|
assert.Contains(t, domains, "config.example.com.", "Mixed case domain should be normalized and pre.sent")
|
||||||
|
assert.Contains(t, domains, "mixed.example.com.", "Mixed case domain should be normalized and present")
|
||||||
|
}
|
||||||
|
|||||||
@@ -2,7 +2,7 @@ package dns
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"fmt"
|
"fmt"
|
||||||
"net"
|
"net/netip"
|
||||||
"sync"
|
"sync"
|
||||||
|
|
||||||
"github.com/google/gopacket"
|
"github.com/google/gopacket"
|
||||||
@@ -117,5 +117,10 @@ func (s *ServiceViaMemory) filterDNSTraffic() (string, error) {
|
|||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
|
|
||||||
return filter.AddUDPPacketHook(false, net.ParseIP(s.runtimeIP), uint16(s.runtimePort), hook), nil
|
ip, err := netip.ParseAddr(s.runtimeIP)
|
||||||
|
if err != nil {
|
||||||
|
return "", fmt.Errorf("parse runtime ip: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return filter.AddUDPPacketHook(false, ip, uint16(s.runtimePort), hook), nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -11,7 +11,6 @@ import (
|
|||||||
"time"
|
"time"
|
||||||
|
|
||||||
"github.com/godbus/dbus/v5"
|
"github.com/godbus/dbus/v5"
|
||||||
"github.com/miekg/dns"
|
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
"golang.org/x/sys/unix"
|
"golang.org/x/sys/unix"
|
||||||
|
|
||||||
@@ -38,7 +37,6 @@ const (
|
|||||||
|
|
||||||
type systemdDbusConfigurator struct {
|
type systemdDbusConfigurator struct {
|
||||||
dbusLinkObject dbus.ObjectPath
|
dbusLinkObject dbus.ObjectPath
|
||||||
routingAll bool
|
|
||||||
ifaceName string
|
ifaceName string
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -112,7 +110,7 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: dns.Fqdn(dConf.Domain),
|
Domain: dConf.Domain,
|
||||||
MatchOnly: dConf.MatchOnly,
|
MatchOnly: dConf.MatchOnly,
|
||||||
})
|
})
|
||||||
|
|
||||||
@@ -124,18 +122,19 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
}
|
}
|
||||||
|
|
||||||
if config.RouteAll {
|
if config.RouteAll {
|
||||||
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
|
||||||
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, true)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting link as default dns router, failed with error: %w", err)
|
return fmt.Errorf("set link as default dns router: %w", err)
|
||||||
}
|
}
|
||||||
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
domainsInput = append(domainsInput, systemdDbusLinkDomainsInput{
|
||||||
Domain: nbdns.RootZone,
|
Domain: nbdns.RootZone,
|
||||||
MatchOnly: true,
|
MatchOnly: true,
|
||||||
})
|
})
|
||||||
s.routingAll = true
|
log.Infof("configured %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
||||||
} else if s.routingAll {
|
} else {
|
||||||
log.Infof("removing %s:%d as main DNS forwarder for this peer", config.ServerIP, config.ServerPort)
|
if err = s.callLinkMethod(systemdDbusSetDefaultRouteMethodSuffix, false); err != nil {
|
||||||
|
return fmt.Errorf("remove link as default dns router: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
state := &ShutdownState{
|
state := &ShutdownState{
|
||||||
@@ -151,6 +150,11 @@ func (s *systemdDbusConfigurator) applyDNSConfig(config HostDNSConfig, stateMana
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
log.Error(err)
|
log.Error(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -163,7 +167,8 @@ func (s *systemdDbusConfigurator) setDomainsForInterface(domainsInput []systemdD
|
|||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
return fmt.Errorf("setting domains configuration failed with error: %w", err)
|
||||||
}
|
}
|
||||||
return s.flushCaches()
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
||||||
@@ -183,10 +188,14 @@ func (s *systemdDbusConfigurator) restoreHostDNS() error {
|
|||||||
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
return fmt.Errorf("unable to revert link configuration, got error: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return s.flushCaches()
|
if err := s.flushDNSCache(); err != nil {
|
||||||
|
log.Errorf("failed to flush DNS cache: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *systemdDbusConfigurator) flushCaches() error {
|
func (s *systemdDbusConfigurator) flushDNSCache() error {
|
||||||
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
obj, closeConn, err := getDbusObject(systemdResolvedDest, systemdDbusObjectNode)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
return fmt.Errorf("attempting to retrieve the object %s, err: %w", systemdDbusObjectNode, err)
|
||||||
|
|||||||
@@ -18,14 +18,16 @@ import (
|
|||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/proto"
|
"github.com/netbirdio/netbird/client/proto"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
|
UpstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
failsTillDeact = int32(5)
|
failsTillDeact = int32(5)
|
||||||
reactivatePeriod = 30 * time.Second
|
reactivatePeriod = 30 * time.Second
|
||||||
upstreamTimeout = 15 * time.Second
|
|
||||||
probeTimeout = 2 * time.Second
|
probeTimeout = 2 * time.Second
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -66,7 +68,7 @@ func newUpstreamResolverBase(ctx context.Context, statusRecorder *peer.Status, d
|
|||||||
ctx: ctx,
|
ctx: ctx,
|
||||||
cancel: cancel,
|
cancel: cancel,
|
||||||
domain: domain,
|
domain: domain,
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
statusRecorder: statusRecorder,
|
statusRecorder: statusRecorder,
|
||||||
@@ -106,9 +108,7 @@ func (u *upstreamResolverBase) ServeDNS(w dns.ResponseWriter, r *dns.Msg) {
|
|||||||
}()
|
}()
|
||||||
|
|
||||||
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
log.Tracef("received upstream question: domain=%s type=%v class=%v", r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
// set the AuthenticatedData flag and the EDNS0 buffer size to 4096 bytes to support larger dns records
|
|
||||||
if r.Extra == nil {
|
if r.Extra == nil {
|
||||||
r.SetEdns0(4096, false)
|
|
||||||
r.MsgHdr.AuthenticatedData = true
|
r.MsgHdr.AuthenticatedData = true
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -336,3 +336,51 @@ func (u *upstreamResolverBase) testNameserver(server string, timeout time.Durati
|
|||||||
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
_, _, err := u.upstreamClient.exchange(ctx, server, r)
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// ExchangeWithFallback exchanges a DNS message with the upstream server.
|
||||||
|
// It first tries to use UDP, and if it is truncated, it falls back to TCP.
|
||||||
|
// If the passed context is nil, this will use Exchange instead of ExchangeContext.
|
||||||
|
func ExchangeWithFallback(ctx context.Context, client *dns.Client, r *dns.Msg, upstream string) (*dns.Msg, time.Duration, error) {
|
||||||
|
// MTU - ip + udp headers
|
||||||
|
// Note: this could be sent out on an interface that is not ours, but our MTU should always be lower.
|
||||||
|
client.UDPSize = iface.DefaultMTU - (60 + 8)
|
||||||
|
|
||||||
|
var (
|
||||||
|
rm *dns.Msg
|
||||||
|
t time.Duration
|
||||||
|
err error
|
||||||
|
)
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
rm, t, err = client.Exchange(r, upstream)
|
||||||
|
} else {
|
||||||
|
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, t, fmt.Errorf("with udp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
if rm == nil || !rm.MsgHdr.Truncated {
|
||||||
|
return rm, t, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Tracef("udp response for domain=%s type=%v class=%v is truncated, trying TCP.",
|
||||||
|
r.Question[0].Name, r.Question[0].Qtype, r.Question[0].Qclass)
|
||||||
|
|
||||||
|
client.Net = "tcp"
|
||||||
|
|
||||||
|
if ctx == nil {
|
||||||
|
rm, t, err = client.Exchange(r, upstream)
|
||||||
|
} else {
|
||||||
|
rm, t, err = client.ExchangeContext(ctx, r, upstream)
|
||||||
|
}
|
||||||
|
|
||||||
|
if err != nil {
|
||||||
|
return nil, t, fmt.Errorf("with tcp: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// TODO: once TCP is implemented, rm.Truncate() if the request came in over UDP
|
||||||
|
|
||||||
|
return rm, t, nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -55,7 +55,7 @@ func (u *upstreamResolver) exchangeWithinVPN(ctx context.Context, upstream strin
|
|||||||
|
|
||||||
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
// exchangeWithoutVPN protect the UDP socket by Android SDK to avoid to goes through the VPN
|
||||||
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchangeWithoutVPN(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
timeout := upstreamTimeout
|
timeout := UpstreamTimeout
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
timeout = time.Until(deadline)
|
timeout = time.Until(deadline)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -34,6 +34,5 @@ func newUpstreamResolver(
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
func (u *upstreamResolver) exchange(ctx context.Context, upstream string, r *dns.Msg) (rm *dns.Msg, t time.Duration, err error) {
|
||||||
upstreamExchangeClient := &dns.Client{}
|
return ExchangeWithFallback(ctx, &dns.Client{}, r, upstream)
|
||||||
return upstreamExchangeClient.ExchangeContext(ctx, r, upstream)
|
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -52,7 +52,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
return nil, 0, fmt.Errorf("error while parsing upstream host: %s", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
timeout := upstreamTimeout
|
timeout := UpstreamTimeout
|
||||||
if deadline, ok := ctx.Deadline(); ok {
|
if deadline, ok := ctx.Deadline(); ok {
|
||||||
timeout = time.Until(deadline)
|
timeout = time.Until(deadline)
|
||||||
}
|
}
|
||||||
@@ -68,7 +68,7 @@ func (u *upstreamResolverIOS) exchange(ctx context.Context, upstream string, r *
|
|||||||
}
|
}
|
||||||
|
|
||||||
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
// Cannot use client.ExchangeContext because it overwrites our Dialer
|
||||||
return client.Exchange(r, upstream)
|
return ExchangeWithFallback(nil, client, r, upstream)
|
||||||
}
|
}
|
||||||
|
|
||||||
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
// GetClientPrivate returns a new DNS client bound to the local IP address of the Netbird interface
|
||||||
|
|||||||
@@ -26,7 +26,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
name: "Should Resolve A Record",
|
name: "Should Resolve A Record",
|
||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
InputServers: []string{"8.8.8.8:53", "8.8.4.4:53"},
|
||||||
timeout: upstreamTimeout,
|
timeout: UpstreamTimeout,
|
||||||
expectedAnswer: "1.1.1.1",
|
expectedAnswer: "1.1.1.1",
|
||||||
},
|
},
|
||||||
{
|
{
|
||||||
@@ -48,7 +48,7 @@ func TestUpstreamResolver_ServeDNS(t *testing.T) {
|
|||||||
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
inputMSG: new(dns.Msg).SetQuestion("one.one.one.one.", dns.TypeA),
|
||||||
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
InputServers: []string{"8.0.0.0:53", "8.8.4.4:53"},
|
||||||
cancelCTX: true,
|
cancelCTX: true,
|
||||||
timeout: upstreamTimeout,
|
timeout: UpstreamTimeout,
|
||||||
responseShouldBeNil: true,
|
responseShouldBeNil: true,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
@@ -122,7 +122,7 @@ func TestUpstreamResolver_DeactivationReactivation(t *testing.T) {
|
|||||||
r: new(dns.Msg),
|
r: new(dns.Msg),
|
||||||
rtt: time.Millisecond,
|
rtt: time.Millisecond,
|
||||||
},
|
},
|
||||||
upstreamTimeout: upstreamTimeout,
|
upstreamTimeout: UpstreamTimeout,
|
||||||
reactivatePeriod: reactivatePeriod,
|
reactivatePeriod: reactivatePeriod,
|
||||||
failsTillDeact: failsTillDeact,
|
failsTillDeact: failsTillDeact,
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -3,34 +3,45 @@ package dnsfwd
|
|||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
|
"math"
|
||||||
"net"
|
"net"
|
||||||
|
"net/netip"
|
||||||
|
"strings"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
|
||||||
"github.com/miekg/dns"
|
"github.com/miekg/dns"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
nbdns "github.com/netbirdio/netbird/dns"
|
nbdns "github.com/netbirdio/netbird/dns"
|
||||||
)
|
)
|
||||||
|
|
||||||
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
const errResolveFailed = "failed to resolve query for domain=%s: %v"
|
||||||
|
const upstreamTimeout = 15 * time.Second
|
||||||
|
|
||||||
type DNSForwarder struct {
|
type DNSForwarder struct {
|
||||||
listenAddress string
|
listenAddress string
|
||||||
ttl uint32
|
ttl uint32
|
||||||
domains []string
|
domains []string
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
dnsServer *dns.Server
|
dnsServer *dns.Server
|
||||||
mux *dns.ServeMux
|
mux *dns.ServeMux
|
||||||
|
|
||||||
|
resId sync.Map
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewDNSForwarder(listenAddress string, ttl uint32) *DNSForwarder {
|
func NewDNSForwarder(listenAddress string, ttl uint32, statusRecorder *peer.Status) *DNSForwarder {
|
||||||
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
log.Debugf("creating DNS forwarder with listen_address=%s ttl=%d", listenAddress, ttl)
|
||||||
return &DNSForwarder{
|
return &DNSForwarder{
|
||||||
listenAddress: listenAddress,
|
listenAddress: listenAddress,
|
||||||
ttl: ttl,
|
ttl: ttl,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) Listen(domains []string) error {
|
func (f *DNSForwarder) Listen(domains []string, resIds map[string]string) error {
|
||||||
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
log.Infof("listen DNS forwarder on address=%s", f.listenAddress)
|
||||||
mux := dns.NewServeMux()
|
mux := dns.NewServeMux()
|
||||||
|
|
||||||
@@ -42,22 +53,30 @@ func (f *DNSForwarder) Listen(domains []string) error {
|
|||||||
f.dnsServer = dnsServer
|
f.dnsServer = dnsServer
|
||||||
f.mux = mux
|
f.mux = mux
|
||||||
|
|
||||||
f.UpdateDomains(domains)
|
f.UpdateDomains(domains, resIds)
|
||||||
|
|
||||||
return dnsServer.ListenAndServe()
|
return dnsServer.ListenAndServe()
|
||||||
}
|
}
|
||||||
|
|
||||||
func (f *DNSForwarder) UpdateDomains(domains []string) {
|
func (f *DNSForwarder) UpdateDomains(domains []string, resIds map[string]string) {
|
||||||
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
log.Debugf("Updating domains from %v to %v", f.domains, domains)
|
||||||
|
|
||||||
for _, d := range f.domains {
|
for _, d := range f.domains {
|
||||||
f.mux.HandleRemove(d)
|
f.mux.HandleRemove(d)
|
||||||
}
|
}
|
||||||
|
f.resId.Clear()
|
||||||
|
|
||||||
newDomains := filterDomains(domains)
|
newDomains := filterDomains(domains)
|
||||||
for _, d := range newDomains {
|
for _, d := range newDomains {
|
||||||
f.mux.HandleFunc(d, f.handleDNSQuery)
|
f.mux.HandleFunc(d, f.handleDNSQuery)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
for domain, resId := range resIds {
|
||||||
|
if domain != "" {
|
||||||
|
f.resId.Store(domain, resId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
f.domains = newDomains
|
f.domains = newDomains
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -79,9 +98,54 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
domain := question.Name
|
domain := question.Name
|
||||||
|
|
||||||
resp := query.SetReply(query)
|
resp := query.SetReply(query)
|
||||||
|
var network string
|
||||||
|
switch question.Qtype {
|
||||||
|
case dns.TypeA:
|
||||||
|
network = "ip4"
|
||||||
|
case dns.TypeAAAA:
|
||||||
|
network = "ip6"
|
||||||
|
default:
|
||||||
|
// TODO: Handle other types
|
||||||
|
|
||||||
ips, err := net.LookupIP(domain)
|
resp.Rcode = dns.RcodeNotImplemented
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), upstreamTimeout)
|
||||||
|
defer cancel()
|
||||||
|
ips, err := net.DefaultResolver.LookupNetIP(ctx, network, domain)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
f.handleDNSError(w, resp, domain, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resId := f.getResIdForDomain(strings.TrimSuffix(domain, "."))
|
||||||
|
if resId != "" {
|
||||||
|
for _, ip := range ips {
|
||||||
|
var ipWithSuffix string
|
||||||
|
if ip.Is4() {
|
||||||
|
ipWithSuffix = ip.String() + "/32"
|
||||||
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ipWithSuffix)
|
||||||
|
} else {
|
||||||
|
ipWithSuffix = ip.String() + "/128"
|
||||||
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ipWithSuffix)
|
||||||
|
}
|
||||||
|
f.statusRecorder.AddResolvedIPLookupEntry(ipWithSuffix, resId)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
f.addIPsToResponse(resp, domain, ips)
|
||||||
|
|
||||||
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
|
log.Errorf("failed to write DNS response: %v", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// handleDNSError processes DNS lookup errors and sends an appropriate error response
|
||||||
|
func (f *DNSForwarder) handleDNSError(w dns.ResponseWriter, resp *dns.Msg, domain string, err error) {
|
||||||
var dnsErr *net.DNSError
|
var dnsErr *net.DNSError
|
||||||
|
|
||||||
switch {
|
switch {
|
||||||
@@ -105,15 +169,16 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
if err := w.WriteMsg(resp); err != nil {
|
if err := w.WriteMsg(resp); err != nil {
|
||||||
log.Errorf("failed to write failure DNS response: %v", err)
|
log.Errorf("failed to write failure DNS response: %v", err)
|
||||||
}
|
}
|
||||||
return
|
}
|
||||||
}
|
|
||||||
|
|
||||||
|
// addIPsToResponse adds IP addresses to the DNS response as appropriate A or AAAA records
|
||||||
|
func (f *DNSForwarder) addIPsToResponse(resp *dns.Msg, domain string, ips []netip.Addr) {
|
||||||
for _, ip := range ips {
|
for _, ip := range ips {
|
||||||
var respRecord dns.RR
|
var respRecord dns.RR
|
||||||
if ip.To4() == nil {
|
if ip.Is6() {
|
||||||
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
log.Tracef("resolved domain=%s to IPv6=%s", domain, ip)
|
||||||
rr := dns.AAAA{
|
rr := dns.AAAA{
|
||||||
AAAA: ip,
|
AAAA: ip.AsSlice(),
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: domain,
|
Name: domain,
|
||||||
Rrtype: dns.TypeAAAA,
|
Rrtype: dns.TypeAAAA,
|
||||||
@@ -125,7 +190,7 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
} else {
|
} else {
|
||||||
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
log.Tracef("resolved domain=%s to IPv4=%s", domain, ip)
|
||||||
rr := dns.A{
|
rr := dns.A{
|
||||||
A: ip,
|
A: ip.AsSlice(),
|
||||||
Hdr: dns.RR_Header{
|
Hdr: dns.RR_Header{
|
||||||
Name: domain,
|
Name: domain,
|
||||||
Rrtype: dns.TypeA,
|
Rrtype: dns.TypeA,
|
||||||
@@ -137,10 +202,36 @@ func (f *DNSForwarder) handleDNSQuery(w dns.ResponseWriter, query *dns.Msg) {
|
|||||||
}
|
}
|
||||||
resp.Answer = append(resp.Answer, respRecord)
|
resp.Answer = append(resp.Answer, respRecord)
|
||||||
}
|
}
|
||||||
|
}
|
||||||
|
|
||||||
if err := w.WriteMsg(resp); err != nil {
|
func (f *DNSForwarder) getResIdForDomain(domain string) string {
|
||||||
log.Errorf("failed to write DNS response: %v", err)
|
var selectedResId string
|
||||||
|
var bestScore int
|
||||||
|
|
||||||
|
f.resId.Range(func(key, value interface{}) bool {
|
||||||
|
var score int
|
||||||
|
pattern := key.(string)
|
||||||
|
|
||||||
|
switch {
|
||||||
|
case strings.HasPrefix(pattern, "*."):
|
||||||
|
baseDomain := strings.TrimPrefix(pattern, "*.")
|
||||||
|
if domain == baseDomain || strings.HasSuffix(domain, "."+baseDomain) {
|
||||||
|
score = len(baseDomain)
|
||||||
}
|
}
|
||||||
|
case domain == pattern:
|
||||||
|
score = math.MaxInt
|
||||||
|
default:
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if score > bestScore {
|
||||||
|
bestScore = score
|
||||||
|
selectedResId = value.(string)
|
||||||
|
}
|
||||||
|
return true
|
||||||
|
})
|
||||||
|
|
||||||
|
return selectedResId
|
||||||
}
|
}
|
||||||
|
|
||||||
// filterDomains returns a list of normalized domains
|
// filterDomains returns a list of normalized domains
|
||||||
|
|||||||
95
client/internal/dnsfwd/forwarder_test.go
Normal file
95
client/internal/dnsfwd/forwarder_test.go
Normal file
@@ -0,0 +1,95 @@
|
|||||||
|
package dnsfwd
|
||||||
|
|
||||||
|
import (
|
||||||
|
"sync"
|
||||||
|
"testing"
|
||||||
|
)
|
||||||
|
|
||||||
|
func TestGetResIdForDomain(t *testing.T) {
|
||||||
|
testCases := []struct {
|
||||||
|
name string
|
||||||
|
storedMappings map[string]string // key: domain pattern, value: resId
|
||||||
|
queryDomain string
|
||||||
|
expectedResId string
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "Empty map returns empty string",
|
||||||
|
storedMappings: map[string]string{},
|
||||||
|
queryDomain: "example.com",
|
||||||
|
expectedResId: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Exact match returns stored resId",
|
||||||
|
storedMappings: map[string]string{"example.com": "res1"},
|
||||||
|
queryDomain: "example.com",
|
||||||
|
expectedResId: "res1",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wildcard pattern matches base domain",
|
||||||
|
storedMappings: map[string]string{"*.example.com": "res2"},
|
||||||
|
queryDomain: "example.com",
|
||||||
|
expectedResId: "res2",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wildcard pattern matches subdomain",
|
||||||
|
storedMappings: map[string]string{"*.example.com": "res3"},
|
||||||
|
queryDomain: "foo.example.com",
|
||||||
|
expectedResId: "res3",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wildcard pattern does not match different domain",
|
||||||
|
storedMappings: map[string]string{"*.example.com": "res4"},
|
||||||
|
queryDomain: "foo.notexample.com",
|
||||||
|
expectedResId: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Non-wildcard pattern does not match subdomain",
|
||||||
|
storedMappings: map[string]string{"example.com": "res5"},
|
||||||
|
queryDomain: "foo.example.com",
|
||||||
|
expectedResId: "",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Exact match over overlapping wildcard",
|
||||||
|
storedMappings: map[string]string{
|
||||||
|
"*.example.com": "resWildcard",
|
||||||
|
"foo.example.com": "resExact",
|
||||||
|
},
|
||||||
|
queryDomain: "foo.example.com",
|
||||||
|
expectedResId: "resExact",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Overlapping wildcards: Select more specific wildcard",
|
||||||
|
storedMappings: map[string]string{
|
||||||
|
"*.example.com": "resA",
|
||||||
|
"*.sub.example.com": "resB",
|
||||||
|
},
|
||||||
|
queryDomain: "bar.sub.example.com",
|
||||||
|
expectedResId: "resB",
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "Wildcard multi-level subdomain match",
|
||||||
|
storedMappings: map[string]string{
|
||||||
|
"*.example.com": "resMulti",
|
||||||
|
},
|
||||||
|
queryDomain: "a.b.example.com",
|
||||||
|
expectedResId: "resMulti",
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tc := range testCases {
|
||||||
|
t.Run(tc.name, func(t *testing.T) {
|
||||||
|
fwd := &DNSForwarder{
|
||||||
|
resId: sync.Map{},
|
||||||
|
}
|
||||||
|
|
||||||
|
for domainPattern, resId := range tc.storedMappings {
|
||||||
|
fwd.resId.Store(domainPattern, resId)
|
||||||
|
}
|
||||||
|
|
||||||
|
got := fwd.getResIdForDomain(tc.queryDomain)
|
||||||
|
if got != tc.expectedResId {
|
||||||
|
t.Errorf("For query domain %q, expected resId %q, but got %q", tc.queryDomain, tc.expectedResId, got)
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
@@ -10,6 +10,7 @@ import (
|
|||||||
|
|
||||||
nberrors "github.com/netbirdio/netbird/client/errors"
|
nberrors "github.com/netbirdio/netbird/client/errors"
|
||||||
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
firewall "github.com/netbirdio/netbird/client/firewall/manager"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
)
|
)
|
||||||
|
|
||||||
const (
|
const (
|
||||||
@@ -20,18 +21,20 @@ const (
|
|||||||
|
|
||||||
type Manager struct {
|
type Manager struct {
|
||||||
firewall firewall.Manager
|
firewall firewall.Manager
|
||||||
|
statusRecorder *peer.Status
|
||||||
|
|
||||||
fwRules []firewall.Rule
|
fwRules []firewall.Rule
|
||||||
dnsForwarder *DNSForwarder
|
dnsForwarder *DNSForwarder
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewManager(fw firewall.Manager) *Manager {
|
func NewManager(fw firewall.Manager, statusRecorder *peer.Status) *Manager {
|
||||||
return &Manager{
|
return &Manager{
|
||||||
firewall: fw,
|
firewall: fw,
|
||||||
|
statusRecorder: statusRecorder,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Start(domains []string) error {
|
func (m *Manager) Start(domains []string, resIds map[string]string) error {
|
||||||
log.Infof("starting DNS forwarder")
|
log.Infof("starting DNS forwarder")
|
||||||
if m.dnsForwarder != nil {
|
if m.dnsForwarder != nil {
|
||||||
return nil
|
return nil
|
||||||
@@ -41,9 +44,9 @@ func (m *Manager) Start(domains []string) error {
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL)
|
m.dnsForwarder = NewDNSForwarder(fmt.Sprintf(":%d", ListenPort), dnsTTL, m.statusRecorder)
|
||||||
go func() {
|
go func() {
|
||||||
if err := m.dnsForwarder.Listen(domains); err != nil {
|
if err := m.dnsForwarder.Listen(domains, resIds); err != nil {
|
||||||
// todo handle close error if it is exists
|
// todo handle close error if it is exists
|
||||||
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
log.Errorf("failed to start DNS forwarder, err: %v", err)
|
||||||
}
|
}
|
||||||
@@ -52,12 +55,12 @@ func (m *Manager) Start(domains []string) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) UpdateDomains(domains []string) {
|
func (m *Manager) UpdateDomains(domains []string, resIds map[string]string) {
|
||||||
if m.dnsForwarder == nil {
|
if m.dnsForwarder == nil {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.dnsForwarder.UpdateDomains(domains)
|
m.dnsForwarder.UpdateDomains(domains, resIds)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (m *Manager) Stop(ctx context.Context) error {
|
func (m *Manager) Stop(ctx context.Context) error {
|
||||||
@@ -88,7 +91,7 @@ func (h *Manager) allowDNSFirewall() error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
dnsRules, err := h.firewall.AddPeerFiltering(net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "", "")
|
dnsRules, err := h.firewall.AddPeerFiltering(nil, net.IP{0, 0, 0, 0}, firewall.ProtocolUDP, nil, dport, firewall.ActionAccept, "")
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
log.Errorf("failed to add allow DNS router rules, err: %v", err)
|
||||||
return err
|
return err
|
||||||
|
|||||||
@@ -34,6 +34,8 @@ import (
|
|||||||
"github.com/netbirdio/netbird/client/internal/dns"
|
"github.com/netbirdio/netbird/client/internal/dns"
|
||||||
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
"github.com/netbirdio/netbird/client/internal/dnsfwd"
|
||||||
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
"github.com/netbirdio/netbird/client/internal/ingressgw"
|
||||||
|
"github.com/netbirdio/netbird/client/internal/netflow"
|
||||||
|
nftypes "github.com/netbirdio/netbird/client/internal/netflow/types"
|
||||||
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
"github.com/netbirdio/netbird/client/internal/networkmonitor"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer"
|
"github.com/netbirdio/netbird/client/internal/peer"
|
||||||
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
"github.com/netbirdio/netbird/client/internal/peer/guard"
|
||||||
@@ -189,6 +191,7 @@ type Engine struct {
|
|||||||
persistNetworkMap bool
|
persistNetworkMap bool
|
||||||
latestNetworkMap *mgmProto.NetworkMap
|
latestNetworkMap *mgmProto.NetworkMap
|
||||||
connSemaphore *semaphoregroup.SemaphoreGroup
|
connSemaphore *semaphoregroup.SemaphoreGroup
|
||||||
|
flowManager nftypes.FlowManager
|
||||||
}
|
}
|
||||||
|
|
||||||
// Peer is an instance of the Connection Peer
|
// Peer is an instance of the Connection Peer
|
||||||
@@ -308,6 +311,12 @@ func (e *Engine) Stop() error {
|
|||||||
time.Sleep(500 * time.Millisecond)
|
time.Sleep(500 * time.Millisecond)
|
||||||
|
|
||||||
e.close()
|
e.close()
|
||||||
|
|
||||||
|
// stop flow manager after wg interface is gone
|
||||||
|
if e.flowManager != nil {
|
||||||
|
e.flowManager.Close()
|
||||||
|
}
|
||||||
|
|
||||||
log.Infof("stopped Netbird Engine")
|
log.Infof("stopped Netbird Engine")
|
||||||
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
ctx, cancel := context.WithTimeout(context.Background(), 3*time.Second)
|
||||||
@@ -342,6 +351,10 @@ func (e *Engine) Start() error {
|
|||||||
}
|
}
|
||||||
e.wgInterface = wgIface
|
e.wgInterface = wgIface
|
||||||
|
|
||||||
|
// start flow manager right after interface creation
|
||||||
|
publicKey := e.config.WgPrivateKey.PublicKey()
|
||||||
|
e.flowManager = netflow.NewManager(e.wgInterface, publicKey[:], e.statusRecorder)
|
||||||
|
|
||||||
if e.config.RosenpassEnabled {
|
if e.config.RosenpassEnabled {
|
||||||
log.Infof("rosenpass is enabled")
|
log.Infof("rosenpass is enabled")
|
||||||
if e.config.RosenpassPermissive {
|
if e.config.RosenpassPermissive {
|
||||||
@@ -448,7 +461,7 @@ func (e *Engine) createFirewall() error {
|
|||||||
}
|
}
|
||||||
|
|
||||||
var err error
|
var err error
|
||||||
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.config.DisableServerRoutes)
|
e.firewall, err = firewall.NewFirewall(e.wgInterface, e.stateManager, e.flowManager.GetLogger(), e.config.DisableServerRoutes)
|
||||||
if err != nil || e.firewall == nil {
|
if err != nil || e.firewall == nil {
|
||||||
log.Errorf("failed creating firewall manager: %s", err)
|
log.Errorf("failed creating firewall manager: %s", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -482,13 +495,13 @@ func (e *Engine) initFirewall() error {
|
|||||||
|
|
||||||
// this rule is static and will be torn down on engine down by the firewall manager
|
// this rule is static and will be torn down on engine down by the firewall manager
|
||||||
if _, err := e.firewall.AddPeerFiltering(
|
if _, err := e.firewall.AddPeerFiltering(
|
||||||
|
nil,
|
||||||
net.IP{0, 0, 0, 0},
|
net.IP{0, 0, 0, 0},
|
||||||
firewallManager.ProtocolUDP,
|
firewallManager.ProtocolUDP,
|
||||||
nil,
|
nil,
|
||||||
&port,
|
&port,
|
||||||
firewallManager.ActionAccept,
|
firewallManager.ActionAccept,
|
||||||
"",
|
"",
|
||||||
"",
|
|
||||||
); err != nil {
|
); err != nil {
|
||||||
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
log.Errorf("failed to allow rosenpass interface traffic: %v", err)
|
||||||
return nil
|
return nil
|
||||||
@@ -512,6 +525,7 @@ func (e *Engine) blockLanAccess() {
|
|||||||
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
v4 := netip.PrefixFrom(netip.IPv4Unspecified(), 0)
|
||||||
for _, network := range toBlock {
|
for _, network := range toBlock {
|
||||||
if _, err := e.firewall.AddRouteFiltering(
|
if _, err := e.firewall.AddRouteFiltering(
|
||||||
|
nil,
|
||||||
[]netip.Prefix{v4},
|
[]netip.Prefix{v4},
|
||||||
network,
|
network,
|
||||||
firewallManager.ProtocolALL,
|
firewallManager.ProtocolALL,
|
||||||
@@ -642,25 +656,14 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
stunTurn = append(stunTurn, e.TURNs...)
|
stunTurn = append(stunTurn, e.TURNs...)
|
||||||
e.stunTurn.Store(stunTurn)
|
e.stunTurn.Store(stunTurn)
|
||||||
|
|
||||||
relayMsg := wCfg.GetRelay()
|
err = e.handleRelayUpdate(wCfg.GetRelay())
|
||||||
if relayMsg != nil {
|
if err != nil {
|
||||||
// when we receive token we expect valid address list too
|
return err
|
||||||
c := &auth.Token{
|
|
||||||
Payload: relayMsg.GetTokenPayload(),
|
|
||||||
Signature: relayMsg.GetTokenSignature(),
|
|
||||||
}
|
|
||||||
if err := e.relayManager.UpdateToken(c); err != nil {
|
|
||||||
log.Errorf("failed to update relay token: %v", err)
|
|
||||||
return fmt.Errorf("update relay token: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
e.relayManager.UpdateServerURLs(relayMsg.Urls)
|
err = e.handleFlowUpdate(wCfg.GetFlow())
|
||||||
|
if err != nil {
|
||||||
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
return fmt.Errorf("handle the flow configuration: %w", err)
|
||||||
// We can ignore all errors because the guard will manage the reconnection retries.
|
|
||||||
_ = e.relayManager.Serve()
|
|
||||||
} else {
|
|
||||||
e.relayManager.UpdateServerURLs(nil)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// todo update signal
|
// todo update signal
|
||||||
@@ -691,6 +694,57 @@ func (e *Engine) handleSync(update *mgmProto.SyncResponse) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleRelayUpdate(update *mgmProto.RelayConfig) error {
|
||||||
|
if update != nil {
|
||||||
|
// when we receive token we expect valid address list too
|
||||||
|
c := &auth.Token{
|
||||||
|
Payload: update.GetTokenPayload(),
|
||||||
|
Signature: update.GetTokenSignature(),
|
||||||
|
}
|
||||||
|
if err := e.relayManager.UpdateToken(c); err != nil {
|
||||||
|
return fmt.Errorf("update relay token: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
e.relayManager.UpdateServerURLs(update.Urls)
|
||||||
|
|
||||||
|
// Just in case the agent started with an MGM server where the relay was disabled but was later enabled.
|
||||||
|
// We can ignore all errors because the guard will manage the reconnection retries.
|
||||||
|
_ = e.relayManager.Serve()
|
||||||
|
} else {
|
||||||
|
e.relayManager.UpdateServerURLs(nil)
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (e *Engine) handleFlowUpdate(config *mgmProto.FlowConfig) error {
|
||||||
|
if config == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
flowConfig, err := toFlowLoggerConfig(config)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return e.flowManager.Update(flowConfig)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toFlowLoggerConfig(config *mgmProto.FlowConfig) (*nftypes.FlowConfig, error) {
|
||||||
|
if config.GetInterval() == nil {
|
||||||
|
return nil, errors.New("flow interval is nil")
|
||||||
|
}
|
||||||
|
return &nftypes.FlowConfig{
|
||||||
|
Enabled: config.GetEnabled(),
|
||||||
|
Counters: config.GetCounters(),
|
||||||
|
URL: config.GetUrl(),
|
||||||
|
TokenPayload: config.GetTokenPayload(),
|
||||||
|
TokenSignature: config.GetTokenSignature(),
|
||||||
|
Interval: config.GetInterval().AsDuration(),
|
||||||
|
DNSCollection: config.GetDnsCollection(),
|
||||||
|
ExitNodeCollection: config.GetExitNodeCollection(),
|
||||||
|
}, nil
|
||||||
|
}
|
||||||
|
|
||||||
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
// updateChecksIfNew updates checks if there are changes and sync new meta with management
|
||||||
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
func (e *Engine) updateChecksIfNew(checks []*mgmProto.Checks) error {
|
||||||
// if checks are equal, we skip the update
|
// if checks are equal, we skip the update
|
||||||
@@ -898,11 +952,6 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
// Apply ACLs in the beginning to avoid security leaks
|
|
||||||
if e.acl != nil {
|
|
||||||
e.acl.ApplyFiltering(networkMap)
|
|
||||||
}
|
|
||||||
|
|
||||||
if e.firewall != nil {
|
if e.firewall != nil {
|
||||||
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
if localipfw, ok := e.firewall.(localIpUpdater); ok {
|
||||||
if err := localipfw.UpdateLocalIPs(); err != nil {
|
if err := localipfw.UpdateLocalIPs(); err != nil {
|
||||||
@@ -913,14 +962,19 @@ func (e *Engine) updateNetworkMap(networkMap *mgmProto.NetworkMap) error {
|
|||||||
|
|
||||||
// DNS forwarder
|
// DNS forwarder
|
||||||
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
dnsRouteFeatureFlag := toDNSFeatureFlag(networkMap)
|
||||||
dnsRouteDomains := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
dnsRouteDomains, resourceIds := toRouteDomains(e.config.WgPrivateKey.PublicKey().String(), networkMap.GetRoutes())
|
||||||
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains)
|
e.updateDNSForwarder(dnsRouteFeatureFlag, dnsRouteDomains, resourceIds)
|
||||||
|
|
||||||
routes := toRoutes(networkMap.GetRoutes())
|
routes := toRoutes(networkMap.GetRoutes())
|
||||||
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
if err := e.routeManager.UpdateRoutes(serial, routes, dnsRouteFeatureFlag); err != nil {
|
||||||
log.Errorf("failed to update clientRoutes, err: %v", err)
|
log.Errorf("failed to update clientRoutes, err: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// acls might need routing to be enabled, so we apply after routes
|
||||||
|
if e.acl != nil {
|
||||||
|
e.acl.ApplyFiltering(networkMap)
|
||||||
|
}
|
||||||
|
|
||||||
// Ingress forward rules
|
// Ingress forward rules
|
||||||
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
if err := e.updateForwardRules(networkMap.GetForwardingRules()); err != nil {
|
||||||
log.Errorf("failed to update forward rules, err: %v", err)
|
log.Errorf("failed to update forward rules, err: %v", err)
|
||||||
@@ -1025,21 +1079,29 @@ func toRoutes(protoRoutes []*mgmProto.Route) []*route.Route {
|
|||||||
return routes
|
return routes
|
||||||
}
|
}
|
||||||
|
|
||||||
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) []string {
|
func toRouteDomains(myPubKey string, protoRoutes []*mgmProto.Route) ([]string, map[string]string) {
|
||||||
if protoRoutes == nil {
|
if protoRoutes == nil {
|
||||||
protoRoutes = []*mgmProto.Route{}
|
protoRoutes = []*mgmProto.Route{}
|
||||||
}
|
}
|
||||||
|
|
||||||
var dnsRoutes []string
|
var dnsRoutes []string
|
||||||
|
resIds := make(map[string]string)
|
||||||
for _, protoRoute := range protoRoutes {
|
for _, protoRoute := range protoRoutes {
|
||||||
if len(protoRoute.Domains) == 0 {
|
if len(protoRoute.Domains) == 0 {
|
||||||
continue
|
continue
|
||||||
}
|
}
|
||||||
if protoRoute.Peer == myPubKey {
|
if protoRoute.Peer == myPubKey {
|
||||||
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
dnsRoutes = append(dnsRoutes, protoRoute.Domains...)
|
||||||
|
// resource ID is the first part of the ID
|
||||||
|
resId := strings.Split(protoRoute.ID, ":")
|
||||||
|
for _, domain := range protoRoute.Domains {
|
||||||
|
if len(resId) > 0 {
|
||||||
|
resIds[domain] = resId[0]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
return dnsRoutes
|
}
|
||||||
|
}
|
||||||
|
return dnsRoutes, resIds
|
||||||
}
|
}
|
||||||
|
|
||||||
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
func toDNSConfig(protoDNSConfig *mgmProto.DNSConfig, network *net.IPNet) nbdns.Config {
|
||||||
@@ -1169,26 +1231,6 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
|||||||
PreSharedKey: e.config.PreSharedKey,
|
PreSharedKey: e.config.PreSharedKey,
|
||||||
}
|
}
|
||||||
|
|
||||||
if e.config.RosenpassEnabled && !e.config.RosenpassPermissive {
|
|
||||||
lk := []byte(e.config.WgPrivateKey.PublicKey().String())
|
|
||||||
rk := []byte(wgConfig.RemoteKey)
|
|
||||||
var keyInput []byte
|
|
||||||
if string(lk) > string(rk) {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(lk[:16], rk[:16]...)
|
|
||||||
} else {
|
|
||||||
//nolint:gocritic
|
|
||||||
keyInput = append(rk[:16], lk[:16]...)
|
|
||||||
}
|
|
||||||
|
|
||||||
key, err := wgtypes.NewKey(keyInput)
|
|
||||||
if err != nil {
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
|
|
||||||
wgConfig.PreSharedKey = &key
|
|
||||||
}
|
|
||||||
|
|
||||||
// randomize connection timeout
|
// randomize connection timeout
|
||||||
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
timeout := time.Duration(rand.Intn(PeerConnectionTimeoutMax-PeerConnectionTimeoutMin)+PeerConnectionTimeoutMin) * time.Millisecond
|
||||||
config := peer.ConnConfig{
|
config := peer.ConnConfig{
|
||||||
@@ -1197,8 +1239,11 @@ func (e *Engine) createPeerConn(pubKey string, allowedIPs []netip.Prefix) (*peer
|
|||||||
Timeout: timeout,
|
Timeout: timeout,
|
||||||
WgConfig: wgConfig,
|
WgConfig: wgConfig,
|
||||||
LocalWgPort: e.config.WgPort,
|
LocalWgPort: e.config.WgPort,
|
||||||
RosenpassPubKey: e.getRosenpassPubKey(),
|
RosenpassConfig: peer.RosenpassConfig{
|
||||||
RosenpassAddr: e.getRosenpassAddr(),
|
PubKey: e.getRosenpassPubKey(),
|
||||||
|
Addr: e.getRosenpassAddr(),
|
||||||
|
PermissiveMode: e.config.RosenpassPermissive,
|
||||||
|
},
|
||||||
ICEConfig: icemaker.Config{
|
ICEConfig: icemaker.Config{
|
||||||
StunTurn: &e.stunTurn,
|
StunTurn: &e.stunTurn,
|
||||||
InterfaceBlackList: e.config.IFaceBlackList,
|
InterfaceBlackList: e.config.IFaceBlackList,
|
||||||
@@ -1706,7 +1751,7 @@ func (e *Engine) GetWgAddr() net.IP {
|
|||||||
}
|
}
|
||||||
|
|
||||||
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
// updateDNSForwarder start or stop the DNS forwarder based on the domains and the feature flag
|
||||||
func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
func (e *Engine) updateDNSForwarder(enabled bool, domains []string, resIds map[string]string) {
|
||||||
if !enabled {
|
if !enabled {
|
||||||
if e.dnsForwardMgr == nil {
|
if e.dnsForwardMgr == nil {
|
||||||
return
|
return
|
||||||
@@ -1720,15 +1765,15 @@ func (e *Engine) updateDNSForwarder(enabled bool, domains []string) {
|
|||||||
if len(domains) > 0 {
|
if len(domains) > 0 {
|
||||||
log.Infof("enable domain router service for domains: %v", domains)
|
log.Infof("enable domain router service for domains: %v", domains)
|
||||||
if e.dnsForwardMgr == nil {
|
if e.dnsForwardMgr == nil {
|
||||||
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall)
|
e.dnsForwardMgr = dnsfwd.NewManager(e.firewall, e.statusRecorder)
|
||||||
|
|
||||||
if err := e.dnsForwardMgr.Start(domains); err != nil {
|
if err := e.dnsForwardMgr.Start(domains, resIds); err != nil {
|
||||||
log.Errorf("failed to start DNS forward: %v", err)
|
log.Errorf("failed to start DNS forward: %v", err)
|
||||||
e.dnsForwardMgr = nil
|
e.dnsForwardMgr = nil
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
log.Infof("update domain router service for domains: %v", domains)
|
log.Infof("update domain router service for domains: %v", domains)
|
||||||
e.dnsForwardMgr.UpdateDomains(domains)
|
e.dnsForwardMgr.UpdateDomains(domains, resIds)
|
||||||
}
|
}
|
||||||
} else if e.dnsForwardMgr != nil {
|
} else if e.dnsForwardMgr != nil {
|
||||||
log.Infof("disable domain router service")
|
log.Infof("disable domain router service")
|
||||||
|
|||||||
@@ -12,6 +12,7 @@ import (
|
|||||||
"testing"
|
"testing"
|
||||||
"time"
|
"time"
|
||||||
|
|
||||||
|
"github.com/golang/mock/gomock"
|
||||||
"github.com/google/uuid"
|
"github.com/google/uuid"
|
||||||
"github.com/pion/transport/v3/stdnet"
|
"github.com/pion/transport/v3/stdnet"
|
||||||
log "github.com/sirupsen/logrus"
|
log "github.com/sirupsen/logrus"
|
||||||
@@ -27,6 +28,8 @@ import (
|
|||||||
|
|
||||||
"github.com/netbirdio/management-integrations/integrations"
|
"github.com/netbirdio/management-integrations/integrations"
|
||||||
|
|
||||||
|
"github.com/netbirdio/netbird/management/server/types"
|
||||||
|
|
||||||
"github.com/netbirdio/netbird/client/iface"
|
"github.com/netbirdio/netbird/client/iface"
|
||||||
"github.com/netbirdio/netbird/client/iface/bind"
|
"github.com/netbirdio/netbird/client/iface/bind"
|
||||||
"github.com/netbirdio/netbird/client/iface/configurer"
|
"github.com/netbirdio/netbird/client/iface/configurer"
|
||||||
@@ -46,6 +49,7 @@ import (
|
|||||||
"github.com/netbirdio/netbird/management/server"
|
"github.com/netbirdio/netbird/management/server"
|
||||||
"github.com/netbirdio/netbird/management/server/activity"
|
"github.com/netbirdio/netbird/management/server/activity"
|
||||||
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
"github.com/netbirdio/netbird/management/server/integrations/port_forwarding"
|
||||||
|
"github.com/netbirdio/netbird/management/server/permissions"
|
||||||
"github.com/netbirdio/netbird/management/server/settings"
|
"github.com/netbirdio/netbird/management/server/settings"
|
||||||
"github.com/netbirdio/netbird/management/server/store"
|
"github.com/netbirdio/netbird/management/server/store"
|
||||||
"github.com/netbirdio/netbird/management/server/telemetry"
|
"github.com/netbirdio/netbird/management/server/telemetry"
|
||||||
@@ -1397,15 +1401,15 @@ func startSignal(t *testing.T) (*grpc.Server, string, error) {
|
|||||||
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, string, error) {
|
||||||
t.Helper()
|
t.Helper()
|
||||||
|
|
||||||
config := &server.Config{
|
config := &types.Config{
|
||||||
Stuns: []*server.Host{},
|
Stuns: []*types.Host{},
|
||||||
TURNConfig: &server.TURNConfig{},
|
TURNConfig: &types.TURNConfig{},
|
||||||
Relay: &server.Relay{
|
Relay: &types.Relay{
|
||||||
Addresses: []string{"127.0.0.1:1234"},
|
Addresses: []string{"127.0.0.1:1234"},
|
||||||
CredentialsTTL: util.Duration{Duration: time.Hour},
|
CredentialsTTL: util.Duration{Duration: time.Hour},
|
||||||
Secret: "222222222222222222",
|
Secret: "222222222222222222",
|
||||||
},
|
},
|
||||||
Signal: &server.Host{
|
Signal: &types.Host{
|
||||||
Proto: "http",
|
Proto: "http",
|
||||||
URI: "localhost:10000",
|
URI: "localhost:10000",
|
||||||
},
|
},
|
||||||
@@ -1435,13 +1439,23 @@ func startManagement(t *testing.T, dataDir, testFile string) (*grpc.Server, stri
|
|||||||
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
metrics, err := telemetry.NewDefaultAppMetrics(context.Background())
|
||||||
require.NoError(t, err)
|
require.NoError(t, err)
|
||||||
|
|
||||||
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock())
|
ctrl := gomock.NewController(t)
|
||||||
|
t.Cleanup(ctrl.Finish)
|
||||||
|
settingsMockManager := settings.NewMockManager(ctrl)
|
||||||
|
settingsMockManager.EXPECT().
|
||||||
|
GetSettings(gomock.Any(), gomock.Any(), gomock.Any()).
|
||||||
|
Return(&types.Settings{}, nil).
|
||||||
|
AnyTimes()
|
||||||
|
|
||||||
|
permissionsManager := permissions.NewManager(store)
|
||||||
|
|
||||||
|
accountManager, err := server.BuildManager(context.Background(), store, peersUpdateManager, nil, "", "netbird.selfhosted", eventStore, nil, false, ia, metrics, port_forwarding.NewControllerMock(), settingsMockManager, permissionsManager)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|
||||||
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay)
|
secretsManager := server.NewTimeBasedAuthSecretsManager(peersUpdateManager, config.TURNConfig, config.Relay, settingsMockManager)
|
||||||
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settings.NewManager(store), peersUpdateManager, secretsManager, nil, nil, nil)
|
mgmtServer, err := server.NewServer(context.Background(), config, accountManager, settingsMockManager, peersUpdateManager, secretsManager, nil, nil, nil)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, "", err
|
return nil, "", err
|
||||||
}
|
}
|
||||||
|
|||||||
Some files were not shown because too many files have changed in this diff Show More
Reference in New Issue
Block a user